# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-SEED Vision Encoder model.
A spatio-temporal vision transformer using the Qwen2.5-VL ViT architecture
(window attention, 3D patch embedding, 2D RoPE), trained with SigLIP-style
sigmoid contrastive loss.
Code-level modifications over the base Qwen2.5-VL ViT:
- transformers 5.x compatibility: RotaryEmbedding recomputes inv_freq on-the-fly
to handle no_init_weights() zeroing (persistent=False register_buffer)
- Float16 numerical stability: autocast paths in PatchMerger and the last
transformer block's MLP
- disable_merger option: skips PatchMerger and returns raw patch features
with window index for external merging
Acknowledgements:
- Architecture adapted from Qwen2.5-VL ViT
(https://github.com/QwenLM/Qwen2.5-VL), Apache-2.0 License.
- Training objective based on SigLIP
(https://github.com/google-research/big_vision), Apache-2.0 License.
"""
from collections.abc import Callable
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel

try:
    from transformers.modeling_layers import GradientCheckpointingLayer
except ImportError:

    class GradientCheckpointingLayer(nn.Module):  # transformers < 4.46
        pass


try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
    ALL_ATTENTION_FUNCTIONS = {}  # transformers < 4.46

from .configuration_hyperclovax_seed_vision_encoder import HyperCLOVAXSeedVisionEncoderConfig


class HyperCLOVAXSeedVisionRMSNorm(nn.Module):
    """RMS normalisation layer."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors."""
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed
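
# Worked example (a sketch): rotate_half pairs dimension i with i + d/2, so
# q * cos + rotate_half(q) * sin applies a 2D rotation in each frequency plane:
#     >>> rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0]))
#     tensor([-3., -4.,  1.,  2.])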


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    hidden_states: (batch, num_key_value_heads, seqlen, head_dim)
        -> (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
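
# Worked example (a sketch): each KV head is repeated in place, so 2 heads with
# n_rep=3 come out in the order [kv0, kv0, kv0, kv1, kv1, kv1] -- exactly what
# torch.repeat_interleave gives:
#     >>> x = torch.randn(1, 2, 5, 8)
#     >>> torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
#     True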


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager (non-fused) scaled dot-product attention, used as fallback."""
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
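
# Sanity-check sketch (assumes PyTorch >= 2.1 for SDPA's `scale` kwarg): with no
# mask or dropout this matches the fused kernel, since both compute
# softmax(q @ k^T * scaling) @ v; note the output here is (batch, seq, heads, dim):
#     >>> mod = nn.Module()
#     >>> mod.num_key_value_groups = 1
#     >>> q = k = v = torch.randn(1, 4, 7, 16)
#     >>> out, _ = eager_attention_forward(mod, q, k, v, None, scaling=16**-0.5)
#     >>> ref = F.scaled_dot_product_attention(q, k, v, scale=16**-0.5)
#     >>> torch.allclose(out, ref.transpose(1, 2), atol=1e-5)
#     True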


class HyperCLOVAXSeedVisionMLP(nn.Module):
    """SwiGLU MLP used inside each vision transformer block."""

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
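
# Shape note (a sketch): forward computes the SwiGLU form
# down_proj(act_fn(gate_proj(x)) * up_proj(x)), mapping (n, hidden_size) through
# two parallel (n, intermediate_size) branches and back to (n, hidden_size).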


class HyperCLOVAXSeedVisionPatchEmbed(nn.Module):
    """3D patch embedding for spatio-temporal inputs via Conv3d."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states
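
# Shape walk-through (a sketch, using the defaults above): each spatio-temporal
# patch flattens to 3 * 2 * 14 * 14 = 1176 pixel values, and the Conv3d (whose
# kernel equals its stride) projects every patch independently:
#     >>> embed = HyperCLOVAXSeedVisionPatchEmbed()
#     >>> embed(torch.randn(4, 1176)).shape   # 4 flattened patches
#     torch.Size([4, 1152])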


class HyperCLOVAXSeedVisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding for vision patches.

    Recomputes ``inv_freq`` in ``forward`` to be robust against
    ``no_init_weights()`` zeroing in transformers 5.x (``persistent=False``).
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        # Recompute inv_freq on the fly: in transformers 5.x, no_init_weights() zeros out
        # register_buffer values, and persistent=False means they aren't restored from checkpoint.
        inv_freq = 1.0 / (self.theta ** (
            torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim
        ))
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.outer(seq, inv_freq)
        return freqs
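
# Worked example (a sketch): with dim=4 and theta=10000, inv_freq == [1.0, 0.01]
# and forward(3) is the outer product of positions [0, 1, 2] with those frequencies:
#     >>> HyperCLOVAXSeedVisionRotaryEmbedding(4)(3)
#     tensor([[0.0000, 0.0000],
#             [1.0000, 0.0100],
#             [2.0000, 0.0200]])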


class HyperCLOVAXSeedVisionPatchMerger(nn.Module):
    """MLP that merges spatially-grouped patches and projects to the output hidden size."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = HyperCLOVAXSeedVisionRMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mlp[0].weight.dtype == torch.float16:
            with torch.amp.autocast(device_type=x.device.type, dtype=torch.float32):
                x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
            x = x.to(self.mlp[0].weight.dtype)
        else:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x
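
# Worked example (a sketch): with context_dim=1152 and spatial_merge_size=2, each
# group of 2 * 2 = 4 neighbouring patch features flattens to 1152 * 4 = 4608 values
# before the two-layer MLP projects it to `dim`:
#     >>> merger = HyperCLOVAXSeedVisionPatchMerger(dim=2048, context_dim=1152)
#     >>> merger(torch.randn(16, 1152)).shape   # 16 patches -> 4 merged tokens
#     torch.Size([4, 2048])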


class HyperCLOVAXSeedVisionAttention(nn.Module):
    """Multi-head self-attention with 2D RoPE, supporting flash-attention and SDPA."""

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        if self.config._attn_implementation == "flash_attention_2":
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
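
# Note (a sketch): `cu_seqlens` holds cumulative boundaries of independent attention
# segments, e.g. cu_seqlens == [0, 4, 10] means tokens 0..3 and 4..9 attend only
# within their own window. The flash-attention path consumes those boundaries
# directly; the eager/SDPA path realises the same masking by splitting q/k/v at
# the boundaries and running attention per segment.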


class HyperCLOVAXSeedVisionBlock(GradientCheckpointingLayer):
    """Transformer block with window or full attention and fp16-safe MLP."""

    def __init__(
        self,
        config: HyperCLOVAXSeedVisionEncoderConfig,
        is_fullatt: bool = False,
        is_last: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.attn = HyperCLOVAXSeedVisionAttention(config=config)
        self.mlp = HyperCLOVAXSeedVisionMLP(config, bias=True)
        self.is_fullatt = is_fullatt
        self.is_last = is_last

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # fp16 full-attention blocks and the last block accumulate rounding error
        # in the MLP; promote to float32 for numerical stability.
        if (
            (not self.is_fullatt and not self.is_last)
            or hidden_states.dtype != torch.float16
        ):
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        else:
            org_type = hidden_states.dtype
            with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
                mlp_out = self.mlp(self.norm2(hidden_states))
            if self.is_last:
                hidden_states = (hidden_states + mlp_out).to(torch.float32)
            else:
                hidden_states = (hidden_states + mlp_out).to(org_type)
        return hidden_states


class HyperCLOVAXSeedVisionEncoder(PreTrainedModel):
    """HyperCLOVAX-SEED Vision Encoder.

    A spatio-temporal vision transformer that encodes images and videos into
    token sequences suitable for a causal language model,
    using window-based and full attention blocks.

    The encoder outputs merged patch embeddings of shape
    ``(total_merged_patches, out_hidden_size)`` where
    ``total_merged_patches = sum(t * h * w / spatial_merge_size^2 for each input)``.
    """

    config_class = HyperCLOVAXSeedVisionEncoderConfig
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.disable_merger = config.disable_merger
        self.patch_embed = HyperCLOVAXSeedVisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = HyperCLOVAXSeedVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([
            HyperCLOVAXSeedVisionBlock(
                config,
                is_fullatt=(_block_idx in config.fullatt_block_indexes),
                is_last=(_block_idx == config.depth - 1),
            )
            for _block_idx in range(config.depth)
        ])
        self.merger = None
        if not self.disable_merger:
            self.merger = HyperCLOVAXSeedVisionPatchMerger(
                dim=config.out_hidden_size,
                context_dim=config.hidden_size,
                spatial_merge_size=config.spatial_merge_size,
            )
        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute 2D rotary position embeddings for all patches in the batch."""
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> tuple[torch.Tensor, list]:
        """Build a flat index that reorders tokens into non-overlapping windows.

        Returns:
            window_index: permutation indices to gather tokens in window order
            cu_window_seqlens: cumulative window sequence lengths for varlen attention
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens
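
    # Worked example (a sketch, with window_size=112, spatial_merge_size=2,
    # patch_size=14): vit_merger_window_size == 112 // 2 // 14 == 4 merged positions
    # per window side. A 30x30-patch image gives a 15x15 merged grid, which is
    # padded to 16x16 with -100 (dropped afterwards) and split into 4 * 4 == 16
    # windows of at most 4 * 4 == 16 merged positions each.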

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(num_patches, in_channels * temporal_patch_size * patch_size**2)`):
                Flattened pixel patches; they are projected through `patch_embed` at
                the start of this call.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width grid dimensions for each input item.

        Returns:
            `torch.Tensor` of shape `(total_merged_patches, out_hidden_size)`, or, when
            `config.disable_merger` is set, a tuple of the unmerged patch features of
            shape `(seq_len, hidden_size)` (in window order) and `window_index` for
            external merging.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        for layer_num, blk in enumerate(self.blocks):
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk,
                    hidden_states,
                    cu_seqlens_now,
                    None,  # rotary_pos_emb (unused; position_embeddings used instead)
                    position_embeddings,
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
        if self.merger is not None:
            hidden_states = self.merger(hidden_states)
            reverse_indices = torch.argsort(window_index)
            hidden_states = hidden_states[reverse_indices, :]
            return hidden_states
        else:
            # window_index to rearrange patches
            return hidden_states, window_index


AutoModel.register(HyperCLOVAXSeedVisionEncoderConfig, HyperCLOVAXSeedVisionEncoder)

__all__ = ["HyperCLOVAXSeedVisionEncoder"]
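
# Usage sketch (hedged; the repo id and pixel preprocessing below are illustrative,
# and loading details may differ for the actual released checkpoint):
#     >>> from transformers import AutoModel
#     >>> encoder = AutoModel.from_pretrained(
#     ...     "HyperCLOVAX-SEED-Think-4B",  # hypothetical local path / repo id
#     ...     trust_remote_code=True,
#     ... )
#     >>> grid_thw = torch.tensor([[1, 32, 32]])        # one image: 32x32 patches
#     >>> pixels = torch.randn(1 * 32 * 32, 3 * 2 * 14 * 14)
#     >>> feats = encoder(pixels, grid_thw=grid_thw)    # (1024 / 4, out_hidden_size)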