# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-SEED Vision Encoder model.

A spatio-temporal vision transformer using the Qwen2.5-VL ViT architecture
(window attention, 3D patch embedding, 2D RoPE), trained with a SigLIP-style
sigmoid contrastive loss.

Code-level modifications over the base Qwen2.5-VL ViT:

- transformers 5.x compatibility: the rotary embedding recomputes ``inv_freq``
  on the fly, because ``no_init_weights()`` zeroes the buffer and, being
  registered with ``persistent=False``, it is not restored from the checkpoint.
- Float16 numerical stability: autocast-to-float32 paths in the PatchMerger and
  in the MLP of the full-attention blocks and of the last transformer block.
- ``disable_merger`` option: skips the PatchMerger and returns the raw patch
  features together with the window index, so merging can happen externally.

Acknowledgements:

- Architecture adapted from Qwen2.5-VL ViT (https://github.com/QwenLM/Qwen2.5-VL),
  Apache-2.0 License.
- Training objective based on SigLIP (https://github.com/google-research/big_vision),
  Apache-2.0 License.
"""

from collections.abc import Callable
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel

try:
    from transformers.modeling_layers import GradientCheckpointingLayer
except ImportError:

    class GradientCheckpointingLayer(nn.Module):  # transformers < 4.46
        """Plain ``nn.Module`` stand-in used when transformers lacks ``modeling_layers``."""


try:
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
    ALL_ATTENTION_FUNCTIONS = {}  # transformers < 4.46

from .configuration_hyperclovax_seed_vision_encoder import HyperCLOVAXSeedVisionEncoderConfig


class HyperCLOVAXSeedVisionRMSNorm(nn.Module):
    """RMS normalisation layer.

    Normalises over the last dimension in float32, casts back to the input
    dtype, then multiplies by a learned per-channel ``weight``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Compute the statistics in float32 to avoid fp16/bf16 overflow/underflow.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input: ``[x1, x2] -> [-x2, x1]``."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    The rotation is computed in float32 and the results are cast back to the
    original query/key dtypes. ``cos``/``sin`` get a singleton dimension inserted
    at position -2 so they broadcast across the heads dimension of ``q``/``k``.

    Returns:
        ``(q_embed, k_embed)`` with the same dtypes as ``q`` and ``k``.
    """
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``.

    Maps ``(batch, num_key_value_heads, seqlen, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Eager (non-fused) scaled dot-product attention, used as fallback.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has its head and
        sequence axes swapped (``transpose(1, 2)``) relative to ``query``.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax in float32 for stability, then back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class HyperCLOVAXSeedVisionMLP(nn.Module):
    """SwiGLU MLP used inside each vision transformer block.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``.
    """

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class HyperCLOVAXSeedVisionPatchEmbed(nn.Module):
    """3D patch embedding for spatio-temporal inputs via Conv3d.

    The kernel equals the stride, so every input patch maps to exactly one
    ``embed_dim`` vector.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed flattened patches; returns a tensor of shape ``(num_patches, embed_dim)``."""
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states


class HyperCLOVAXSeedVisionRotaryEmbedding(nn.Module):
    """2D rotary position embedding for vision patches.

    Recomputes ``inv_freq`` in ``forward`` to be robust against
    ``no_init_weights()`` zeroing in transformers 5.x (``persistent=False``).
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` table of position * frequency products."""
        # Recompute inv_freq on the fly: in transformers 5.x, no_init_weights() zeros out
        # register_buffer values, and persistent=False means they aren't restored from checkpoint.
        # The buffer is still used to pick the device.
        inv_freq = 1.0 / (
            self.theta
            ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)
        )
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=inv_freq.dtype)
        freqs = torch.outer(seq, inv_freq)
        return freqs


class HyperCLOVAXSeedVisionPatchMerger(nn.Module):
    """MLP that merges spatially-grouped patches and projects to the output hidden size.

    Each group of ``spatial_merge_size**2`` consecutive ``context_dim`` tokens is
    normalised, concatenated into one ``context_dim * spatial_merge_size**2``
    vector and projected to ``dim``.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = HyperCLOVAXSeedVisionRMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.mlp[0].weight.dtype == torch.float16:
            # fp16 weights: run the merger under float32 autocast for numerical
            # stability, then cast the result back to the weight dtype.
            with torch.amp.autocast(device_type=x.device.type, dtype=torch.float32):
                x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
                x = x.to(self.mlp[0].weight.dtype)
        else:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x


class HyperCLOVAXSeedVisionAttention(nn.Module):
    """Multi-head self-attention with 2D RoPE, supporting flash-attention and SDPA.

    Attention is restricted to the segments delimited by ``cu_seqlens``
    (windows or whole images), either via varlen flash-attention or by
    splitting the sequence and attending within each segment.
    """

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment.

        Args:
            hidden_states: ``(seq_len, hidden_size)`` token features.
            cu_seqlens: cumulative segment boundaries (starting at 0).
            rotary_pos_emb: raw rotary angles; only used to build ``(cos, sin)``
                when ``position_embeddings`` is not given.
            position_embeddings: precomputed ``(cos, sin)`` pair.

        Raises:
            ValueError: if neither ``position_embeddings`` nor ``rotary_pos_emb`` is given.
        """
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )

        if position_embeddings is None:
            if rotary_pos_emb is None:
                raise ValueError(
                    "HyperCLOVAXSeedVisionAttention requires either `position_embeddings` or `rotary_pos_emb`."
                )
            # Same construction as HyperCLOVAXSeedVisionEncoder.forward.
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: varlen kernel handles the segments natively.
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: split along the sequence axis and attend
            # within each segment separately (no mask needed).
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output


class HyperCLOVAXSeedVisionBlock(GradientCheckpointingLayer):
    """Transformer block with window or full attention and fp16-safe MLP.

    Pre-norm residual layout: ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``.
    """

    def __init__(
        self,
        config: HyperCLOVAXSeedVisionEncoderConfig,
        is_fullatt: bool = False,
        is_last: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.attn = HyperCLOVAXSeedVisionAttention(config=config)
        self.mlp = HyperCLOVAXSeedVisionMLP(config, bias=True)
        self.is_fullatt = is_fullatt
        self.is_last = is_last

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run one block.

        When the input is float16 and this is the last block, the output is
        returned as float32 (see the MLP branch below).
        """
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # fp16 full-attention blocks and the last block accumulate rounding error
        # in the MLP; promote to float32 for numerical stability.
        if (not self.is_fullatt and not self.is_last) or hidden_states.dtype != torch.float16:
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        else:
            org_type = hidden_states.dtype
            with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
                mlp_out = self.mlp(self.norm2(hidden_states))
                if self.is_last:
                    # Keep the final block output in float32.
                    hidden_states = (hidden_states + mlp_out).to(torch.float32)
                else:
                    hidden_states = (hidden_states + mlp_out).to(org_type)
        return hidden_states


class HyperCLOVAXSeedVisionEncoder(PreTrainedModel):
    """HyperCLOVAX-SEED Vision Encoder.

    A spatio-temporal vision transformer that encodes images and videos into
    token sequences suitable for a causal language model, using window-based
    and full attention blocks.

    With the merger enabled (default), the encoder outputs merged patch
    embeddings of shape ``(total_merged_patches, out_hidden_size)`` where
    ``total_merged_patches = sum(t * h * w / spatial_merge_size^2 for each input)``.
    With ``config.disable_merger`` set, it instead returns the tuple
    ``(hidden_states, window_index)`` described in :meth:`forward`.
    """

    config_class = HyperCLOVAXSeedVisionEncoderConfig
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    # Newer transformers releases read this renamed flag instead of `_supports_flash_attn_2`.
    _supports_flash_attn = True
    _supports_sdpa = True

    def __init__(self, config: HyperCLOVAXSeedVisionEncoderConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.disable_merger = config.disable_merger

        self.patch_embed = HyperCLOVAXSeedVisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        head_dim = config.hidden_size // config.num_heads
        # Half of head_dim per spatial axis (height and width) -> head_dim // 2.
        self.rotary_pos_emb = HyperCLOVAXSeedVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                HyperCLOVAXSeedVisionBlock(
                    config,
                    is_fullatt=(_block_idx in config.fullatt_block_indexes),
                    is_last=(_block_idx == config.depth - 1),
                )
                for _block_idx in range(config.depth)
            ]
        )

        self.merger = None
        if not self.disable_merger:
            self.merger = HyperCLOVAXSeedVisionPatchMerger(
                dim=config.out_hidden_size,
                context_dim=config.hidden_size,
                spatial_merge_size=config.spatial_merge_size,
            )

        self.gradient_checkpointing = False

        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute 2D rotary position embeddings for all patches in the batch.

        Position ids are laid out in merge-group order (each
        ``spatial_merge_size x spatial_merge_size`` group is contiguous) and
        repeated for every temporal step.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # Look up (h, w) angles and concatenate them along the last axis.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> tuple[torch.Tensor, list]:
        """Build a flat index that reorders tokens into non-overlapping windows.

        Indices are over merge groups (``spatial_merge_size**2`` patches each).

        Returns:
            window_index: permutation indices to gather tokens in window order
            cu_window_seqlens: cumulative window sequence lengths (in patches) for
                varlen attention; may contain repeated values for empty windows
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad to a multiple of the window size; -100 marks padding.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Flattened raw patch pixels. They are embedded by `patch_embed` inside this
                method, which views them as `(-1, in_channels, temporal_patch_size, patch_size,
                patch_size)`; typically the shape is
                `(seq_len, in_channels * temporal_patch_size * patch_size**2)`.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width grid dimensions for each input item.

        Returns:
            If the merger is enabled: `torch.Tensor` of shape `(total_merged_patches, out_hidden_size)`,
            in the original (non-windowed) order.

            If `config.disable_merger` is set: a tuple `(hidden_states, window_index)`, where
            `hidden_states` has shape `(seq_len, hidden_size)` and is still in window order, and
            `window_index` is the merge-group permutation that was applied (invert it with
            `torch.argsort(window_index)` after merging). With float16 inputs, `hidden_states`
            is float32 because the last block promotes its output.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop duplicate boundaries produced by empty (fully padded) windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (in whole merge groups) into window order.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Full-attention boundaries: one segment per temporal frame of each item.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk,
                    hidden_states,
                    cu_seqlens_now,
                    None,  # rotary_pos_emb (unused; position_embeddings used instead)
                    position_embeddings,
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        if self.merger is not None:
            hidden_states = self.merger(hidden_states)
            # Undo the window reordering (merged tokens map 1:1 to merge groups).
            reverse_indices = torch.argsort(window_index)
            hidden_states = hidden_states[reverse_indices, :]
            return hidden_states
        else:
            # window_index to rearrange patches
            return hidden_states, window_index


AutoModel.register(HyperCLOVAXSeedVisionEncoderConfig, HyperCLOVAXSeedVisionEncoder)

__all__ = ["HyperCLOVAXSeedVisionEncoder"]