HyperCLOVAX-SEED-Think-4B / modeling_hyperclovax_seed_vision_v2.py
bigshanedogg's picture
Upload folder using huggingface_hub
0c1d6f8 verified
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX-Vision-V2 multimodal model.
Integrates a vision encoder, vision projector, causal language model, and
optionally an audio encoder. The published model uses:
- Language model: HyperCLOVAX or Llama
- Vision encoder: HyperCLOVAXSeedVisionEncoder + PatchMerger projector
- Audio encoder: HyperCLOVAXSeedAudioEncoder + MLP projector
Acknowledgements:
- VLM integration pattern adapted from LLaVA
(https://github.com/haotian-liu/LLaVA), Apache-2.0 License.
- CAbstractor and weight initialization adapted from Honeybee
(https://github.com/kakaobrain/honeybee), Apache-2.0 License.
- PatchMerger projector adapted from Qwen2.5-VL
(https://github.com/QwenLM/Qwen2.5-VL), Apache-2.0 License.
"""
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
except ImportError:
pass
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from .configuration_hyperclovax_seed_vision_v2 import HyperCLOVAXVisionV2Config, ProjectorType
from .configuration_hyperclovax_seed_vision_encoder import HyperCLOVAXSeedVisionEncoderConfig
try:
from transformers import Qwen2_5_VLVisionConfig
except ImportError:
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
class HyperCLOVAXVisionV2MLP(nn.Module):
    """Two-layer feed-forward projector (plain or inverted-bottleneck).

    The layout is ``fc1 -> act -> fc2``. For ``ProjectorType.MLP`` the inner
    width equals ``hidden_features``; for ``ProjectorType.INVERTED_MLP`` it is
    twice that.

    Args:
        vision_projector_type: Projector variant; compared against
            ``ProjectorType.MLP`` and ``ProjectorType.INVERTED_MLP``.
        in_features: Size of the last dimension of the input.
        hidden_features: Base inner width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.

    Raises:
        NotImplementedError: For any other ``vision_projector_type``.
    """

    def __init__(
        self,
        vision_projector_type: str,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.vision_projector_type = vision_projector_type

        # The two supported variants differ only in the width between fc1 and fc2.
        if vision_projector_type == ProjectorType.MLP:
            inner_width = hidden_features
        elif vision_projector_type == ProjectorType.INVERTED_MLP:
            inner_width = 2 * hidden_features
        else:
            raise NotImplementedError(f"{vision_projector_type} is not implemented")

        self.fc1 = nn.Linear(in_features, inner_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(inner_width, out_features)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply ``fc2(act(fc1(x)))`` over the last dimension of ``x``."""
        return self.fc2(self.act(self.fc1(x)))
class HyperCLOVAXVisionV2CAbstractor(nn.Module):
    """Convolutional abstractor (C-Abstractor) that resamples patch features.

    Adapted from the C-Abstractor in Honeybee. A flat patch sequence of shape
    ``(B, L, encoder_hidden_size)`` is folded into a square grid, passed
    through a RegNet stage, average-pooled to the target grid, passed through
    a second RegNet stage and finally mapped to ``output_hidden_size`` by an
    MLP readout.

    Args:
        num_queries: Number of output visual tokens (must be a perfect square).
        num_input_tokens: Number of input patch tokens; sizes the positional embedding.
        encoder_hidden_size: Channel size of the incoming patch features.
        hidden_size: Channel size used inside the RegNet stages.
        output_hidden_size: Channel size of the returned tokens.
        pos_emb: If ``True``, add a learnable positional embedding to the input.
        prenorm: If ``True``, apply LayerNorm to the input first.
        depth: Depth passed to each RegNet stage.
        mlp_depth: Number of linear layers in the readout MLP.

    Raises:
        ValueError: If ``num_queries`` is not a perfect square.
    """

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb: bool = True,
        prenorm: bool = False,
        depth: int = 3,
        mlp_depth: int = 2,
    ):
        super().__init__()
        side = num_queries ** 0.5
        if not side.is_integer():
            raise ValueError(f"num_queries must be a perfect square, got {num_queries}")
        grid_side = int(side)

        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Learnable positional embedding, drawn from N(0, 0.02^2) when enabled.
        self.pos_emb: Optional[nn.Parameter]
        if not pos_emb:
            self.pos_emb = None
        else:
            initial = torch.zeros(1, num_input_tokens, encoder_hidden_size)
            initial.normal_(mean=0.0, std=0.02)
            self.pos_emb = nn.Parameter(initial)

        self.prenorm = None
        if prenorm:
            self.prenorm = LayerNorm(encoder_hidden_size)

        def make_stage(in_channels: int, out_channels: int) -> nn.Module:
            # Stride-1 RegNet stage: the spatial size only changes in the pooling layer.
            return RegStage(
                depth,
                in_channels,
                out_channels,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )

        # Layout (stage, fixed pooling, stage) is relied upon by _forward_adaptive.
        first_stage = make_stage(encoder_hidden_size, hidden_size)
        fixed_pool = nn.AdaptiveAvgPool2d((grid_side, grid_side))
        second_stage = make_stage(hidden_size, hidden_size)
        self.net = nn.Sequential(first_stage, fixed_pool, second_stage)

        readout_layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(1, mlp_depth):
            readout_layers.extend(
                [nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size)]
            )
        self.readout = nn.Sequential(*readout_layers)

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Resample patch features into visual tokens.

        Args:
            x: ``(B, L, encoder_hidden_size)`` patch features; ``L`` is folded
                into an ``int(sqrt(L))``-sided square grid.
            num_queries_vis_abstractors: Per-slice query counts for adaptive pooling.
                If ``None``, the fixed grid size from ``__init__`` is used.
            num_grids: Cumulative boundary indices along the batch axis matching
                ``num_queries_vis_abstractors``. Required when the above is set.

        Returns:
            ``(B, num_queries, output_hidden_size)`` tensor on the fixed-grid
            path, or a list of per-slice tensors on the adaptive path.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)
        if self.pos_emb is not None:
            x = x + self.pos_emb

        # Fold the flat patch sequence into a grid: [B, L, d] -> [B, d, h, w]
        grid_side = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=grid_side, w=grid_side)

        if num_queries_vis_abstractors is None:
            tokens = rearrange(self.net(x), "b d h w -> b (h w) d")
            return self.readout(tokens)

        assert num_grids is not None
        return self._forward_adaptive(x, num_queries_vis_abstractors, num_grids)

    def _forward_adaptive(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: List[int],
        num_grids: List[int],
    ) -> List[torch.Tensor]:
        """Adaptive-query forward: swap the fixed pooling for a per-slice pooling size."""
        # self.net is (stage, fixed pooling, stage); the fixed pooling is skipped here.
        assert len(self.net) == 3
        features = self.net[0](x)

        results = []
        for idx, query_count in enumerate(num_queries_vis_abstractors):
            side = int(query_count ** 0.5)
            start, stop = num_grids[idx], num_grids[idx + 1]
            pool = nn.AdaptiveAvgPool2d((side, side))
            pooled = pool(features[start:stop, :])
            refined = self.net[2](pooled)
            tokens = rearrange(refined, "b d h w -> b (h w) d")
            results.append(self.readout(tokens))
        return results
class HyperCLOVAXVisionV2RMSNorm(nn.Module):
    """RMS normalisation layer used inside HyperCLOVAXVisionV2PatchMerger.

    Args:
        hidden_size: Size of the last (normalised) dimension.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Scale ``hidden_states`` by the inverse RMS of its last dimension.

        The statistics are computed in float32; the normalised values are cast
        back to the input dtype before being multiplied by ``weight``.
        """
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self) -> str:
        """Return the weight shape and epsilon for ``repr(module)``."""
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class HyperCLOVAXVisionV2PatchMerger(nn.Module):
    """Patch-merger projector that maps vision tokens to LLM embedding space.

    Adapted from the PatchMerger in Qwen2.5-VL. Takes the pair
    ``(hidden_states, window_index)``, RMS-normalises the hidden states,
    groups ``spatial_merge_size ** 2`` consecutive tokens into one row, runs a
    two-layer MLP over each row and finally undoes the ordering described by
    ``window_index``.

    Args:
        dim: Output hidden size (= LLM hidden size).
        context_dim: Input hidden size (= vision encoder ``out_hidden_size``).
        spatial_merge_size: Spatial merge factor used in the vision encoder
            (default 2, matching Qwen2.5-VL defaults).
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
    ) -> None:
        super().__init__()
        merged_width = context_dim * (spatial_merge_size ** 2)
        self.hidden_size = merged_width
        self.ln_q = HyperCLOVAXVisionV2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        )

    def forward(
        self,
        inputs: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Merge and project vision tokens.

        Args:
            inputs: Tuple of ``(hidden_states, window_index)`` produced by the
                monkey-patched Qwen vision encoder forward.

        Returns:
            Tensor of shape ``(total_tokens, dim)`` in the original token order.
        """
        hidden_states, window_index = inputs

        def project(tokens: torch.Tensor) -> torch.Tensor:
            # Normalise per token, then fold groups of tokens into single rows.
            return self.mlp(self.ln_q(tokens).view(-1, self.hidden_size))

        if self.mlp[0].weight.dtype != torch.float16:
            merged = project(hidden_states)
        else:
            # fp16 weights: run the merge under a float32 autocast region to limit
            # rounding error in the linear layers (matches vLLM behaviour).
            with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
                merged = project(hidden_states)

        # argsort of the permutation is its inverse, restoring the original order.
        restore_order = torch.argsort(window_index)
        return merged[restore_order, :]
class HyperCLOVAXVisionV2PreTrainedModel(PreTrainedModel):
    """Base class for all HyperCLOVAX-Vision-V2 models.

    Declares the config class, the capability flags read by the Transformers
    loading machinery, and the weight-initialisation scheme shared by the
    concrete models.
    """

    config_class = HyperCLOVAXVisionV2Config
    base_model_prefix = "model"
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock", "Qwen2DecoderLayer", "LlamaDecoderLayer"]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(
        self,
        module: nn.Module,
    ) -> None:
        """Initialize weights following Honeybee conventions.

        Conv/Embedding/Linear weights are drawn from N(0, 0.02^2) and their
        biases zeroed; LayerNorm is reset to identity (weight 1, bias 0).
        Parameters that are absent (``None``) are skipped.

        Args:
            module: The sub-module to initialise in place.
        """
        # https://github.com/kakaobrain/honeybee/blob/main/honeybee/common_layers.py#L55
        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or bias=False) has
            # None for the corresponding parameter; guard instead of crashing.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class HyperCLOVAXVisionV2Model(HyperCLOVAXVisionV2PreTrainedModel):
    """Backbone model: vision encoder + multimodal projector + LLM base (no LM head).

    An audio encoder and audio projector are additionally built when
    ``config.audio_config`` is a ``PretrainedConfig``.
    """

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        """Build the encoders, projectors and language model described by ``config``.

        Args:
            config: Composite configuration holding ``vision_config``,
                ``text_config`` and (optionally) ``audio_config``.

        Raises:
            ValueError: If ``config.anyres`` is set, no ``possible_resolutions``
                are configured and ``config.max_num_grids`` is not positive.
        """
        super().__init__(config)

        # The dtype may be exposed as either `dtype` or `torch_dtype`; resolve it
        # once so every sub-config and projector is treated the same way.
        model_dtype = getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)

        # vision encoder
        vision_config = config.vision_config
        vision_config.anyres = config.anyres
        vision_config.max_num_grids = config.max_num_grids
        vision_config.torch_dtype = model_dtype
        self.vision_config = vision_config

        if config.anyres:
            if not getattr(config, "possible_resolutions", []):
                if not config.max_num_grids > 0:
                    raise ValueError(
                        f"max_num_grids must be positive when anyres is enabled, got {config.max_num_grids}"
                    )
                # Every (rows, cols) tiling with at most max_num_grids tiles; the
                # 1x1 tiling is only included when config.use_1x1_grid is set.
                self.config.possible_resolutions = [
                    [rows * vision_config.image_size, cols * vision_config.image_size]
                    for rows in range(1, config.max_num_grids + 1)
                    for cols in range(1, config.max_num_grids + 1)
                    if (rows != 1 or cols != 1 or config.use_1x1_grid)
                    and rows * cols <= config.max_num_grids
                ]
            else:
                self.config.possible_resolutions = config.possible_resolutions

        if vision_config.model_type != Qwen2_5_VLVisionConfig.model_type:
            vision_config._attn_implementation = config._attn_implementation
        if not vision_config.name_or_path:
            vision_config._name_or_path = config._name_or_path
        self.vision_model = AutoModel.from_config(
            vision_config,
            trust_remote_code=True,
            attn_implementation=config._attn_implementation,
        )

        # language model
        text_config = config.text_config
        text_config.torch_dtype = model_dtype
        if text_config.model_type in ["llama", "hyperclovax", "gpt2"]:
            text_config._attn_implementation = config._attn_implementation
        if text_config.model_type != "hyperclovax":
            text_config.logits_scaling = 1.0
        text_config.vocab_size = (
            text_config.padded_vocab_size if hasattr(text_config, "padded_vocab_size") else text_config.vocab_size
        )
        self.language_model = AutoModelForCausalLM.from_config(text_config, trust_remote_code=True)
        self.text_config = text_config
        self.num_queries_vis_abstractor = config.num_queries_vis_abstractor

        # vision projector (connector)
        input_hidden_size = vision_config.hidden_size
        if vision_config.model_type == Qwen2_5_VLVisionConfig.model_type:
            input_hidden_size = vision_config.out_hidden_size

        if config.vision_projector_type == ProjectorType.LINEAR:
            self.mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
        elif config.vision_projector_type == ProjectorType.CABSTRACTOR:
            self.mm_projector = HyperCLOVAXVisionV2CAbstractor(
                num_queries=self.num_queries_vis_abstractor,
                num_input_tokens=(vision_config.image_size // vision_config.patch_size) ** 2,
                encoder_hidden_size=input_hidden_size,
                hidden_size=input_hidden_size,
                output_hidden_size=text_config.hidden_size,
                pos_emb=config.proj_pos_emb,
                prenorm=config.proj_prenorm,
            )
            # No separate cast of `pos_emb` here: Tensor.to() is not in-place (so
            # a bare call has no effect), `pos_emb` is None when proj_pos_emb is
            # False, and the module-level cast below already covers it.
        elif config.vision_projector_type == ProjectorType.PATCH_MERGER:
            # Custom patch-merger with HyperCLOVAX RMSNorm and fp16 autocast.
            # Requires the Qwen vision encoder to be monkey-patched so it returns
            # (hidden_states, window_index) instead of applying its built-in merger.
            self.mm_projector = HyperCLOVAXVisionV2PatchMerger(
                dim=text_config.hidden_size,
                context_dim=input_hidden_size,
            )
        else:
            self.mm_projector = HyperCLOVAXVisionV2MLP(
                config.vision_projector_type,
                input_hidden_size,
                hidden_features=input_hidden_size,
                out_features=text_config.hidden_size,
            )
        if model_dtype is not None:
            self.mm_projector.to(model_dtype)

        self.vision_feature_layer = config.vision_feature_layer
        self.anyres = config.anyres
        if self.anyres:
            self.image_newline = nn.Parameter(torch.empty(text_config.hidden_size, dtype=self.dtype))
        else:
            # Keep the attribute defined so readers can test it against None.
            self.image_newline = None

        # audio encoder
        self.audio_model = None
        self.audio_projector = None
        audio_config = getattr(config, "audio_config", None)
        if isinstance(audio_config, PretrainedConfig):
            audio_config.torch_dtype = model_dtype
            if not audio_config.name_or_path:
                audio_config._name_or_path = config._name_or_path
            self.audio_model = AutoModel.from_config(
                audio_config,
                trust_remote_code=True,
                attn_implementation=config._attn_implementation,
            )
            if config.audio_projector_type == ProjectorType.LINEAR:
                self.audio_projector = nn.Linear(
                    in_features=audio_config.d_model,
                    out_features=text_config.hidden_size,
                )
            else:
                self.audio_projector = HyperCLOVAXVisionV2MLP(
                    config.audio_projector_type,
                    audio_config.d_model,
                    hidden_features=audio_config.d_model,
                    out_features=text_config.hidden_size,
                )
            self.audio_projector.to(self.audio_model.dtype)

    def process_audio_input(
        self,
        audio_values: torch.Tensor,
        audio_attention_mask: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Encode audio chunks into LLM embedding space.

        Args:
            audio_values: ``(total_chunks, 128, 3000)`` mel spectrogram tensor.
            audio_attention_mask: ``(total_chunks, 3000)`` attention mask.

        Returns:
            List containing one tensor of shape ``(total_chunks * T, hidden_size)``.
        """
        emb = self.audio_model(
            audio_values,
            attention_mask=audio_attention_mask,
        ).last_hidden_state  # (total_chunks, T, d_model)
        emb = emb.flatten(0, 1)  # (total_chunks * T, d_model)
        emb = self.audio_projector(emb)
        return [emb]

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the language model's input embedding layer."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        """Replace the language model's input embedding layer."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        """Return the language model's output embedding (LM head) layer."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(
        self,
        new_embeddings: nn.Linear,
    ) -> None:
        """Replace the language model's output embedding (LM head) layer."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        """Return the language model's decoder."""
        return self.language_model.get_decoder()

    def set_decoder(
        self,
        decoder: nn.Module,
    ) -> None:
        """Replace the language model's decoder."""
        self.language_model.set_decoder(decoder)

    def tie_weights(
        self,
        **kwargs,
    ) -> None:
        """Tie input/output embeddings via the language model's ``tie_weights``."""
        # Under device_map="auto", embed_tokens and lm_head may land on different
        # CUDA devices. The new transformers tie_weights() calls torch.equal() on
        # both tensors before deciding whether to tie them, which raises RuntimeError
        # when the tensors are on different devices. Move lm_head.weight to the
        # same device as embed_tokens.weight beforehand so the comparison succeeds.
        if getattr(self.config.text_config, "tie_word_embeddings", False):
            input_embeddings = self.language_model.get_input_embeddings()
            output_embeddings = self.language_model.get_output_embeddings()
            if (
                input_embeddings is not None
                and output_embeddings is not None
                and input_embeddings.weight.device != output_embeddings.weight.device
            ):
                output_embeddings.weight = nn.Parameter(output_embeddings.weight.to(input_embeddings.weight.device))
        return self.language_model.tie_weights(**kwargs)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """Resize the language model's embeddings and record the new vocab size."""
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    @staticmethod
    def _scatter_modality_features(
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int,
        features: List[torch.Tensor],
        modality: str,
    ) -> None:
        """Overwrite, in place, the embeddings at ``token_id`` positions with ``features``.

        Raises:
            ValueError: If the number of placeholder positions does not match
                the number of encoded feature vectors.
        """
        positions = input_ids.eq(token_id).nonzero(as_tuple=False)
        merged = torch.cat(features).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        hidden_size = inputs_embeds.size(-1)
        # Compare element counts so a mismatch fails loudly instead of being
        # silently broadcast by the indexed assignment below.
        if merged.numel() != positions.size(0) * hidden_size:
            raise ValueError(
                f"Mismatch between {modality} placeholder tokens ({positions.size(0)}) and "
                f"encoded {modality} feature vectors ({merged.numel() // hidden_size})."
            )
        inputs_embeds[positions[:, 0], positions[:, 1]] = merged

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # audio inputs (from processor)
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        audio_masks: Optional[List[torch.Tensor]] = None,  # reserved; not used in forward
        num_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_masks: Optional[List[torch.Tensor]] = None,  # reserved; not used in forward
        num_video_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Fuse multimodal inputs into token embeddings and run the language model backbone.

        When ``inputs_embeds`` is not given, ``input_ids`` are embedded and the
        positions holding the image, video and video-audio placeholder token
        IDs are overwritten with the corresponding encoder+projector outputs.
        When ``inputs_embeds`` is given it is passed to the language model
        unchanged.

        Returns:
            ``BaseModelOutputWithPast`` (or tuple when ``return_dict=False``).

        Raises:
            ValueError: If ``audio_values`` is given, if neither ``input_ids``
                nor ``inputs_embeds`` is given, or if the number of placeholder
                tokens does not match the number of encoded features.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if audio_values is not None:
            raise ValueError(
                "Standalone audio input (`audio_values`) is not supported by this model. "
                "Audio is only supported as part of video input (`video_audio_values`)."
            )

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must specify either `input_ids` or `inputs_embeds`.")

            # With device_map="auto", accelerate hooks may have a stale execution_device
            # that differs from the actual weight device (e.g. due to tied embeddings).
            # Bypass the hook by calling F.embedding directly so that input and weight
            # are guaranteed to be on the same device.
            embed_module = self.get_input_embeddings()
            inputs_embeds = F.embedding(
                input_ids.to(embed_module.weight.device),
                embed_module.weight,
                embed_module.padding_idx,
            )

            if pixel_values is not None:
                image_features = self.process_image_input(
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thw,
                )
                self._scatter_modality_features(
                    inputs_embeds, input_ids, self.config.image_token_id, image_features, "image"
                )

            if pixel_values_videos is not None:
                video_features = self.process_video_input(
                    pixel_values_videos=pixel_values_videos,
                    video_grid_thw=video_grid_thw,
                )
                self._scatter_modality_features(
                    inputs_embeds, input_ids, self.config.video_token_id, video_features, "video"
                )

            if video_audio_values is not None and self.audio_model is not None:
                video_audio_token_id = getattr(self.config, "video_audio_token_id", None)
                if video_audio_token_id is not None:
                    video_audio_features = self.process_audio_input(
                        audio_values=video_audio_values,
                        audio_attention_mask=video_audio_attention_mask,
                    )
                    self._scatter_modality_features(
                        inputs_embeds, input_ids, video_audio_token_id, video_audio_features, "video audio"
                    )

            # The language model must receive only the fused embeddings.
            input_ids = None

        return self.language_model.base_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def process_image_input(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ) -> List[torch.Tensor]:
        """Encode image pixel values into LLM-space feature tensors.

        Args:
            pixel_values: Flat tensor of shape ``(total_patches, channels * patch_size * patch_size)``.
            image_grid_thw: Grid shape ``(num_images, 3)`` with (T, H, W) per image.

        Returns:
            List containing one tensor of shape ``(total_image_tokens, hidden_size)``.
        """
        features = self.vision_model(pixel_values, grid_thw=image_grid_thw)
        features = self.mm_projector(features)
        return [features]

    def process_video_input(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ) -> List[torch.Tensor]:
        """Encode video pixel values into LLM-space feature tensors.

        Args:
            pixel_values_videos: Flat tensor of shape ``(total_patches, channels * patch_size * patch_size)``.
            video_grid_thw: Grid shape ``(num_videos, 3)`` with (T, H, W) per video.

        Returns:
            List containing one tensor of shape ``(total_video_tokens, hidden_size)``.
        """
        features = self.vision_model(pixel_values_videos, grid_thw=video_grid_thw)
        features = self.mm_projector(features)
        return [features]
class HyperCLOVAXVisionV2ForCausalLM(HyperCLOVAXVisionV2PreTrainedModel, GenerationMixin):
    """HyperCLOVAX-Vision-V2 model with a causal language modelling head.

    Wraps ``HyperCLOVAXVisionV2Model`` as ``self.model`` and reuses the wrapped
    language model's ``lm_head`` to produce logits.
    """

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        super().__init__(config)
        self.model = HyperCLOVAXVisionV2Model(config)
        self.post_init()

    # Delegate embedding / decoder accessors to the inner model
    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Linear:
        return self.model.get_output_embeddings()

    def set_output_embeddings(
        self,
        new_embeddings: nn.Linear,
    ) -> None:
        self.model.set_output_embeddings(new_embeddings)

    def get_decoder(self) -> nn.Module:
        return self.model.get_decoder()

    def set_decoder(
        self,
        decoder: nn.Module,
    ) -> None:
        self.model.set_decoder(decoder)

    def tie_weights(
        self,
        **kwargs,
    ) -> None:
        return self.model.tie_weights(**kwargs)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        return self.model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)

    # Convenience properties
    @property
    def language_model(self) -> nn.Module:
        return self.model.language_model

    @property
    def vision_model(self) -> nn.Module:
        return self.model.vision_model

    @property
    def mm_projector(self) -> nn.Module:
        return self.model.mm_projector

    @property
    def audio_model(self) -> Optional[nn.Module]:
        return self.model.audio_model

    @property
    def audio_projector(self) -> Optional[nn.Module]:
        return self.model.audio_projector

    @property
    def vision_model_type(self) -> str:
        return self.model.vision_config.model_type

    @property
    def anyres(self) -> bool:
        return self.model.anyres

    @property
    def image_newline(self) -> Optional[nn.Parameter]:
        # The inner model may not define this attribute when anyres is off;
        # honour the Optional contract instead of raising AttributeError.
        return getattr(self.model, "image_newline", None)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # audio inputs (from processor)
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        audio_masks: Optional[List[torch.Tensor]] = None,  # reserved; not used in forward
        num_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_masks: Optional[List[torch.Tensor]] = None,  # reserved; not used in forward
        num_video_audio_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Multimodal causal language model forward pass.

        Calls the backbone model to fuse multimodal inputs, then computes logits
        via the LM head (scaled by ``text_config.logits_scaling`` when present).
        Loss is computed against ``labels`` when provided; extra ``**kwargs``
        are forwarded to the loss function only.

        Args:
            logits_to_keep: If an ``int``, logits are computed for the last
                ``logits_to_keep`` positions (``0`` keeps all). Otherwise it is
                used directly as an index along the sequence dimension.

        Returns:
            ``CausalLMOutputWithPast`` (or tuple when ``return_dict=False``).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the backbone: the attribute access
        # below would fail on a plain tuple. The caller's `return_dict` choice is
        # applied to the final result instead.
        # NOTE(review): `.forward` is called directly, which skips the module's
        # __call__ hooks — confirm this is intentional.
        outputs = self.model.forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            audio_values=audio_values,
            audio_attention_mask=audio_attention_mask,
            image_grid_thw=image_grid_thw,
            num_image_tokens=num_image_tokens,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
            num_video_tokens=num_video_tokens,
            video_audio_values=video_audio_values,
            video_audio_attention_mask=video_audio_attention_mask,
            video_audio_masks=video_audio_masks,
            num_video_audio_tokens=num_video_audio_tokens,
        )

        hidden_states = outputs[0]
        # An int selects the trailing positions (0 -> slice(0, None), i.e. all).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.model.language_model.lm_head(hidden_states[:, slice_indices, :]) * getattr(
            self.config.text_config, "logits_scaling", 1.0
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.text_config.vocab_size,
                **kwargs,
            )

        result = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return result if return_dict else result.to_tuple()

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.FloatTensor] = None,
        video_audio_values: Optional[torch.FloatTensor] = None,
        video_audio_attention_mask: Optional[torch.FloatTensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the per-step model inputs, attaching multimodal tensors on prefill only.

        Returns:
            The dict produced by the parent implementation, extended with the
            multimodal inputs when no tokens have been cached yet.
        """
        # Overwritten -- multimodal inputs are declared as explicit named params
        # so they are naturally excluded from **kwargs and do not leak into super().
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        # Prefill detection: no past KV cache yet.
        # - transformers 4.x: past_key_values is None
        # - transformers 5.x: pre-creates an empty DynamicCache, so get_seq_length() == 0
        is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw
            model_inputs["pixel_values_videos"] = pixel_values_videos
            model_inputs["video_grid_thw"] = video_grid_thw
            model_inputs["audio_values"] = audio_values
            model_inputs["audio_attention_mask"] = audio_attention_mask
            model_inputs["video_audio_values"] = video_audio_values
            model_inputs["video_audio_attention_mask"] = video_audio_attention_mask
        return model_inputs
class HyperCLOVAXVisionV2ForSequenceClassification(HyperCLOVAXVisionV2PreTrainedModel):
    """HyperCLOVAX-Vision-V2 model with a sequence classification head."""

    def __init__(
        self,
        config: HyperCLOVAXVisionV2Config,
    ) -> None:
        super().__init__(config)
        self.num_labels = getattr(config, "num_labels", 2)
        self.model = HyperCLOVAXVisionV2Model(config)
        # The backbone emits hidden states of the text model's width; fall back
        # to it when the composite config does not expose `hidden_size` itself.
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = config.text_config.hidden_size
        self.score = nn.Linear(hidden_size, self.num_labels, bias=False)
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the inner model's input embedding layer."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(
        self,
        value: nn.Embedding,
    ) -> None:
        """Replace the inner model's input embedding layer."""
        self.model.set_input_embeddings(value)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        # vision inputs (from processor)
        image_grid_thw: Optional[torch.LongTensor] = None,
        num_image_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
        # video inputs (from processor)
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        num_video_tokens: Optional[torch.LongTensor] = None,  # reserved; not used in forward
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        """Sequence classification forward pass.

        Projects every hidden state via ``self.score`` and keeps, per sequence,
        the logits at the last non-padding token (or at the last position when
        no pad token is configured or ``input_ids`` is not given). Loss is
        computed against ``labels`` when provided.

        Returns:
            ``SequenceClassifierOutputWithPast`` (or tuple when ``return_dict=False``).

        Raises:
            ValueError: If the batch size is larger than 1 and no padding
                token is defined.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request a ModelOutput from the backbone: the attribute access
        # below would fail on a plain tuple. The caller's `return_dict` choice is
        # applied to the final result instead.
        transformer_outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            image_grid_thw=image_grid_thw,
            num_image_tokens=num_image_tokens,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
            num_video_tokens=num_video_tokens,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            last_non_pad_token = -1
        else:
            # Multiply each position index by its non-pad mask; the argmax is then
            # the right-most non-padding position of every row.
            non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        result = SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        return result if return_dict else result.to_tuple()
# Make the config and the three model heads discoverable through the Auto* factories.
AutoConfig.register("hyperclovax_vision_v2", HyperCLOVAXVisionV2Config)
for _auto_class, _model_class in (
    (AutoModel, HyperCLOVAXVisionV2Model),
    (AutoModelForCausalLM, HyperCLOVAXVisionV2ForCausalLM),
    (AutoModelForSequenceClassification, HyperCLOVAXVisionV2ForSequenceClassification),
):
    _auto_class.register(HyperCLOVAXVisionV2Config, _model_class)
del _auto_class, _model_class

# Public API of this module.
__all__ = [
    "HyperCLOVAXVisionV2PreTrainedModel",
    "HyperCLOVAXVisionV2Model",
    "HyperCLOVAXVisionV2ForCausalLM",
    "HyperCLOVAXVisionV2ForSequenceClassification",
]