# HyperCLOVAX-SEED-Omni-8B / modeling_vlm.py
import contextlib
import math
import os
from functools import partial
from itertools import chain
from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
try:
from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
except ImportError:
    print("Packages required for anyres could not be imported.")
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
)
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import no_init_weights
from .configuration_vlm import HCXVisionConfig
try:
from .cosyvoice import (
DEFAULT_SAMPLE_RATE,
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES,
CosyvoiceEncoder,
)
except ImportError:
    print("Packages required for discrete audio could not be imported.")
try:
from .ta_tok import TextAlignedTokenizer
except ImportError:
    print("Packages required for discrete vision could not be imported.")
try:
from .mambamia_videoaudio_compressor import (
MambaMiaVideoAudioCompressor,
MambaMiaVideoAudioCompressorConfig,
)
except ImportError:
    print("Packages required for the MambaMia video-audio compressor could not be imported.")
def get_rank():
if dist.is_initialized():
return dist.get_rank()
return 0
def is_ampere_or_newer():
    if not torch.cuda.is_available():
        return False
    # Ampere and newer GPUs report CUDA compute capability >= 8.0, which is
    # the minimum required for flash-attention 2.
    major, _ = torch.cuda.get_device_capability()
    return major >= 8
EOT = "<|endofturn|>"
IMG_LOC = "<|IMAGE_PAD|>"
class HCXVisionPreTrainedModel(PreTrainedModel):
config_class = HCXVisionConfig
base_model_prefix = "model"
vision_model_name = "vision_model"
_no_split_modules = [
"CLIPAttention",
"SiglipVisionModel",
]
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = True
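    # Weight initialization: conv / embedding / linear layers use normal(0, 0.02);
    # LayerNorm gets zero bias and unit weight; bare nn.Parameter objects
    # (presumably learned positional embeddings) use std 1/sqrt(dim).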
def _init_weights(self, module):
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
module.weight.data.normal_(mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Parameter):
embed_std = 1 / torch.sqrt(
torch.tensor(module.size(0), dtype=torch.float)
).to(module.dtype)
module.data.normal_(mean=0.0, std=embed_std)
class HCXVisionModel(HCXVisionPreTrainedModel):
def __init__(
self,
config: HCXVisionConfig,
without_llm=False,
**kwargs,
):
super().__init__(config)
self.flag_changed_max_position_embeddings = False
self.without_llm = without_llm
vision_model_type = config.vision_config.model_type
self.is_qwen_visual = False
if vision_model_type == "qwen2_5_vl_visual":
self.is_qwen_visual = True
self.freeze_before_sampler = kwargs.pop("freeze_before_sampler", False)
vision_config = config.vision_config
vision_config.anyres = config.anyres
vision_config.max_num_grids = config.max_num_grids
vision_config.update({"torch_dtype": config.torch_dtype})
self.vision_config = vision_config
if without_llm:
vision_config.vison_pretrained_name_or_path = (
config.vision_model_name_or_path
)
if self.is_qwen_visual and is_ampere_or_newer():
vision_config._attn_implementation = "flash_attention_2"
self.vision_model = AutoModel.from_config(vision_config, trust_remote_code=True)
if not self.config.freeze_encoder:
self.vision_model.gradient_checkpointing_enable()
else:
self.vision_model.eval()
if vision_config.model_type == "qwen2_5_vl_visual":
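            # The Qwen2.5-VL vision blocks are monkey-patched below so that,
            # when running in fp16, each patched block's MLP executes under
            # fp32 autocast, presumably to avoid activation overflow in half
            # precision.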
import torch.nn.functional as F
def new_block_forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
                # Always apply the MLP; under fp16 it runs in fp32 autocast to
                # avoid overflow, then casts back to the original dtype.
                if hidden_states.dtype == torch.float16:
                    org_type = hidden_states.dtype
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
                        hidden_states = hidden_states + self.mlp(
                            self.norm2(hidden_states)
                        )
                    hidden_states = hidden_states.to(org_type)
                else:
                    hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
                return hidden_states
def new_last_block_forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states),
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
                # Always apply the MLP; under fp16 it runs in fp32 autocast and
                # the final hidden states are promoted to fp32.
                if hidden_states.dtype == torch.float16:
                    with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
                        hidden_states = hidden_states + self.mlp(
                            self.norm2(hidden_states)
                        )
                    hidden_states = hidden_states.to(torch.float32)
                else:
                    hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
                return hidden_states
def new_merger_forward(self, x: torch.Tensor) -> torch.Tensor:
if self.mlp[0].weight.dtype == torch.float16:
with torch.amp.autocast(device_type="cuda", dtype=torch.float32):
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
x = x.to(self.mlp[0].weight.dtype)
else:
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
return x
import types
for b_idx, blk in enumerate(self.vision_model.blocks):
if b_idx == len(self.vision_model.blocks) - 1:
blk.forward = types.MethodType(
new_last_block_forward, self.vision_model.blocks[b_idx]
)
elif b_idx in vision_config.fullatt_block_indexes:
blk.forward = types.MethodType(
new_block_forward, self.vision_model.blocks[b_idx]
)
self.vision_model.merger.forward = types.MethodType(
new_merger_forward, self.vision_model.merger
)
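        # For the "qwen_merger" projector, the vision tower's forward is
        # replaced so that it returns the pre-merge hidden states together
        # with window_index; the patch-merging step then lives in
        # mm_projector, which restores the original token order via
        # argsort(window_index) after projection.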
if config.mm_projector_type == "qwen_merger":
import torch.nn.functional as F
def new_forward(
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
The final hidden states of the model.
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
The temporal, height and width of feature shape of each image in LLM.
Returns:
`torch.Tensor`: hidden_states.
"""
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(
dim=0,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
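                # cu_seqlens holds cumulative token boundaries per image/video;
                # blocks listed in fullatt_block_indexes attend across a whole
                # image, while the remaining blocks attend only within local
                # windows (cu_window_seqlens).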
for layer_num, blk in enumerate(self.blocks):
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
blk.__call__,
hidden_states,
cu_seqlens_now,
None,
position_embeddings,
)
else:
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings,
)
return hidden_states, window_index
import types
self.vision_model.forward = types.MethodType(new_forward, self.vision_model)
self.vision_model.merger = nn.Identity()
self.discrete_vision_model = None
discrete_vision_config = config.discrete_vision_config
if discrete_vision_config is not None and isinstance(
discrete_vision_config, dict
):
discrete_vision_config.update({"torch_dtype": config.torch_dtype})
self.discrete_vision_config = discrete_vision_config
if (
hasattr(config, "discrete_vision_config")
and config.discrete_vision_config is not None
):
try:
assert (
discrete_vision_config["model_type"] == "ta_tok"
), "Only 'ta_tok' discrete vision model is supported currently."
self.discrete_vision_model = TextAlignedTokenizer.from_checkpoint(
discrete_vision_config["model_name_or_path"],
load_teacher=False,
input_type="indices",
)
if self.discrete_vision_model is not None:
self.discrete_vision_model.eval()
except Exception as e:
print(f"Warning: Failed to initialize discrete vision model: {e}")
self.discrete_vision_model = None
self.audio_model = None
audio_config = config.audio_config
if audio_config is not None and isinstance(audio_config, dict):
audio_config.update({"torch_dtype": config.torch_dtype})
self.audio_config = audio_config
if hasattr(config, "audio_config") and config.audio_config is not None:
self.audio_model = (
AutoModel.from_config(audio_config)
if audio_config is not None
else None
)
if self.audio_model is not None:
self.audio_model.eval()
self.discrete_audio_model = None
discrete_audio_config = config.discrete_audio_config
if discrete_audio_config is not None and isinstance(
discrete_audio_config, dict
):
discrete_audio_config.update({"torch_dtype": config.torch_dtype})
self.discrete_audio_config = discrete_audio_config
if (
hasattr(config, "discrete_audio_config")
and config.discrete_audio_config is not None
):
try:
assert (
discrete_audio_config["model_type"] == "cosyvoice2"
), "Only 'cosyvoice2' discrete audio model is supported currently."
self.discrete_audio_model = CosyvoiceEncoder.from_pretrained(
discrete_audio_config["model_name_or_path"]
)
if self.discrete_audio_model is not None:
self.discrete_audio_model.eval()
except Exception as e:
print(f"Warning: Failed to initialize discrete audio model: {e}")
self.discrete_audio_model = None
if hasattr(config, "text_config") and config.text_config is not None:
text_config = config.text_config
else:
raise ValueError("text_config is not defined")
text_config.update({"torch_dtype": config.torch_dtype})
if config.text_config.model_type in ["llama", "hyperclovax", "gpt2"]:
text_config._attn_implementation = config._attn_implementation
if text_config.model_type != "hyperclovax":
text_config.logits_scaling = 1.0
text_config.vocab_size = (
text_config.padded_vocab_size
if hasattr(text_config, "padded_vocab_size")
else text_config.vocab_size
)
if not without_llm:
with no_init_weights():
self.language_model = AutoModelForCausalLM.from_config(
text_config, trust_remote_code=True
)
if config.text_config.model_type in ["llama", "hyperclovax", "gpt2"]:
if not self.config.freeze_decoder:
self.language_model.gradient_checkpointing_enable()
if (
config.text_config.model_type == "hyperclovax"
and hasattr(self, "use_liger")
and self.use_liger
):
self.language_model._get_apply_liger_kernel_converter()(
model=self.language_model
)
print("liger kernel for hcx 24b will be used")
self.num_queries_vis_abstractor = config.num_queries_vis_abstractor
input_hidden_size = vision_config.hidden_size
if vision_config.model_type == "qwen2_5_vl_visual":
input_hidden_size = vision_config.out_hidden_size
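        # The multimodal projector maps vision features into the LLM embedding
        # space; supported types are "linear", "cabstractor", "qwen_merger",
        # and an MLP fallback (VLM_Mlp).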
if config.mm_projector_type == "linear":
self.mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
elif config.mm_projector_type == "cabstractor":
self.mm_projector = CAbstractor(
num_queries=self.num_queries_vis_abstractor,
num_input_tokens=(
self.vision_config.image_size // self.vision_config.patch_size
)
** 2,
encoder_hidden_size=input_hidden_size,
hidden_size=input_hidden_size,
output_hidden_size=text_config.hidden_size,
pos_emb=config.proj_pos_emb,
prenorm=config.proj_prenorm,
)
self.mm_projector.pos_emb.to(config.torch_dtype)
elif config.mm_projector_type == "qwen_merger":
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLPatchMerger,
)
self.mm_projector = Qwen2_5_VLPatchMerger(
dim=text_config.hidden_size, context_dim=input_hidden_size
)
def new_forward(self, inputs) -> torch.Tensor:
x, window_index = inputs
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
reverse_indices = torch.argsort(window_index)
x = x[reverse_indices, :]
return x
self.mm_projector.forward = types.MethodType(new_forward, self.mm_projector)
else:
self.mm_projector = VLM_Mlp(
config.mm_projector_type,
input_hidden_size,
hidden_features=input_hidden_size,
out_features=text_config.hidden_size,
)
if audio_config is None:
self.audio_projector = None
else:
if config.audio_projector_type == "linear":
self.audio_projector = nn.Linear(
audio_config.d_model, text_config.hidden_size
)
else:
self.audio_projector = VLM_Mlp(
config.audio_projector_type,
audio_config.d_model,
hidden_features=audio_config.d_model,
out_features=text_config.hidden_size,
)
self.video_audio_compressor = None
if (
audio_config is not None
and config.video_audio_compressor_type == "mambamia"
):
compressor_config = MambaMiaVideoAudioCompressorConfig(
input_size=text_config.hidden_size,
output_size=text_config.hidden_size,
chunk_size=25,
num_hidden_layers=1,
)
self.video_audio_compressor = MambaMiaVideoAudioCompressor(
compressor_config
)
self.video_audio_compressor = self.video_audio_compressor.to(
config.torch_dtype
)
self.use_nth_layer = config.use_nth_layer
self.model_parallel = False
self.device_map = None
self.vision_model_use_no_grad = None
self.text_config = text_config
self.anyres = False
self.unpad = config.unpad
self.vision_input_chunk_size = kwargs.pop("vision_input_chunk_size", None)
self.is_safetensor_save = kwargs.get("is_safetensor_save", True)
self._backward_compatibility_gradient_checkpointing()
self.mm_projector.to(config.torch_dtype)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
image_sizes: Optional[List[List[List[int]]]] = None,
mm_query_lengths: Optional[List[List[int]]] = None,
non_mm_query_lengths: Optional[List[List[int]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
video_audio_values: Optional[torch.FloatTensor] = None,
video_audio_masks: Optional[torch.FloatTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_values: Optional[torch.LongTensor] = None,
discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
audio_masks: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer.
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param past_key_values: None
:param inputs_embeds: None
:param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1)+ num visual tokens)] visual token 들은 모두 IGNORE_INDEX
:param use_cache: None
:param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
:param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함)
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
:first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch.
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video.
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:audio_values: TBD
:discrete_audio_values: TBD
:audio_masks: TBD
:return:
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.vision_config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.vision_config.output_hidden_states
)
if inputs_embeds is None and past_key_values is None:
inputs_embeds, labels = self.extract_inputs_embeds(
input_ids=input_ids,
labels=labels,
pixel_values=pixel_values,
discrete_pixel_values=discrete_pixel_values,
past_key_values=past_key_values,
image_sizes=image_sizes,
mm_query_lengths=mm_query_lengths,
non_mm_query_lengths=non_mm_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
video_audio_values=video_audio_values,
video_audio_masks=video_audio_masks,
audio_values=audio_values,
discrete_audio_values=discrete_audio_values,
discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample,
audio_masks=audio_masks,
)
if inputs_embeds is not None:
input_ids = None
outputs = self.language_model.base_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
labels=labels,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return outputs
def determine_non_mm_query_lengths(self, input_ids, pad_id, img_start_id):
"""non_mm_query_lengths 를 계산하는 함수
input_ids 가 collate 될때, 오른쪽에 pad_id 가 채워지기 때문에 이 값을 찾는 방식을 통해 계산됨
또한 img_start_id 는 visual token 이 들어서는 자리이기 때문에, 해당 indices 은 제거
"""
non_mm_query_lengths = []
batch_size, len_seq = input_ids.size(0), input_ids.size(1)
for i in range(batch_size):
temp_idx = (input_ids[i] == pad_id).nonzero()
eos_idx = temp_idx[0, 0].item() if len(temp_idx) > 0 else len_seq
num_imgs = (input_ids[i] == img_start_id).sum().item()
non_mm_query_lengths.append(eos_idx - num_imgs)
if all([pad_id in input_id for input_id in input_ids.tolist()]):
non_mm_query_lengths = [
non_mm_query_length + 1 for non_mm_query_length in non_mm_query_lengths
]
return non_mm_query_lengths
def determine_mm_query_lengths(self, image_features, image_cnts):
"""mm_query_lengths 를 계산하는 함수
image_features tensor 의 shape 을 통해 계산된다.
이미지가 1장도 없는 sample 의 경우 dummy image 1장이 들어가기 때문에, 따로 빈 list 처리 또한 추가
"""
mm_query_lengths = [
[image_feature.size(0) for image_feature in image_feature_list]
for image_feature_list in image_features
]
for i, image_cnt in enumerate(image_cnts):
if image_cnt == 0:
assert len(mm_query_lengths[i]) == 1
mm_query_lengths[i] = []
return mm_query_lengths
def get_input_embeddings(self):
if self.without_llm:
return None
else:
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
if self.without_llm:
return None
else:
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
if self.without_llm:
return None
else:
return self.language_model.tie_weights()
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
) -> nn.Embedding:
model_embeds = self.language_model.resize_token_embeddings(
new_num_tokens, pad_to_multiple_of
)
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
def extract_inputs_embeds(
self,
input_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
image_sizes: Optional[List[List[List[int]]]] = None,
mm_query_lengths: Optional[List[List[int]]] = None,
non_mm_query_lengths: Optional[List[int]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
video_audio_values: Optional[torch.FloatTensor] = None,
video_audio_masks: Optional[torch.FloatTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
audio_masks: Optional[torch.FloatTensor] = None,
):
"""
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer.
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
In cases where a sample contains no images, a single dummy image is included.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param past_key_values: None : (batch_size, num_heads, sequence_length - 1, embed_size_per_head): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
:first_last_frames_slows: A List of bool that contains the information of whether the slowfast algorithm is applied to the first or last frames of the video.
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video.
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:audio_values: TBD
:discrete_audio_values: TBD
:audio_masks: TBD
:return:
"""
inputs_embeds = None
        if not past_key_values:
inputs_embeds = self.get_input_embeddings()(input_ids)
context_vision_model = (
torch.no_grad()
if self.config.freeze_encoder
else contextlib.nullcontext()
)
if self.config.freeze_encoder:
self.vision_model.eval()
self.vision_model.gradient_checkpointing = False
if (
pixel_values is not None
and len([i for i in pixel_values if i is not None]) > 0
):
with context_vision_model:
image_features = self.vision_model(
pixel_values, grid_thw=image_grid_thw
)
image_features = self.mm_projector(image_features)
if img_start_ids_list is None:
image_cnts = (
(input_ids == self.config.img_start_id).sum(dim=1).tolist()
)
else:
image_cnts = [
len(img_start_ids) for img_start_ids in img_start_ids_list
]
mask = input_ids.eq(self.config.img_start_id)
positions = mask.nonzero(as_tuple=False)
batch_idx = positions[:, 0]
seq_idx = positions[:, 1]
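                # Scatter the projected image features into the img_start_id
                # placeholder positions. When the batch has no real images, a
                # zero-length slice is grafted instead, presumably so the
                # vision tower still participates in the autograd graph (e.g.
                # for DDP gradient synchronization).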
if sum(image_cnts) == 0:
image_features = image_features[0:0]
inputs_embeds[batch_idx, seq_idx, :] = image_features
if (
pixel_values_videos is not None
and len([i for i in pixel_values_videos if i is not None]) > 0
):
with context_vision_model:
video_features = self.vision_model(
pixel_values_videos, grid_thw=video_grid_thw
)
video_features = self.mm_projector(video_features)
video_cnts = (
(input_ids == self.config.video_start_id).sum(dim=1).tolist()
)
mask = input_ids.eq(self.config.video_start_id)
positions = mask.nonzero(as_tuple=False)
batch_idx = positions[:, 0]
seq_idx = positions[:, 1]
if sum(video_cnts) == 0:
inputs_embeds[0, 0:0] = video_features[0:0]
else:
inputs_embeds[batch_idx, seq_idx, :] = video_features
del video_features, mask, positions, batch_idx, seq_idx
if (
discrete_pixel_values is not None
and len([i for i in discrete_pixel_values if i is not None]) > 0
):
assert (
self.config.discrete_image_start_id is not None
), "discrete_image_start_id 가 정의되어 있지 않습니다"
discrete_image_cnt = (
(input_ids == self.config.discrete_image_start_id)
.sum(dim=1)
.tolist()
)
with torch.no_grad():
discrete_unit_lists = self.discrete_vision_model(
discrete_pixel_values
)["encoded"]
discrete_unit_lists = discrete_unit_lists.detach().cpu()
rank = int(os.environ.get("RANK", -1))
if (discrete_unit_lists < 0).any() or (
discrete_unit_lists > 65535
).any():
min_val = discrete_unit_lists.min().item()
max_val = discrete_unit_lists.max().item()
print(f"[RANK {rank}] ❌ discrete_unit_lists has invalid values!")
print(f" min={min_val}, max={max_val}")
print(
f" Expected range: [0, 65535] for TA-Tok codebook (2^16 = 65536 codes)"
)
print(f" Will map to tokens: <|vision00000|> ~ <|vision65535|>")
print(f" Clamping to valid range...")
discrete_unit_lists = torch.clamp(discrete_unit_lists, 0, 65535)
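                # Discrete vision codes index a 2^16-entry codebook; adding
                # discrete_image_unit_0_id maps code k to the id of the
                # <|vision{k:05d}|> token in the text vocabulary.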
discrete_image_token_ids = (
discrete_unit_lists + self.config.discrete_image_unit_0_id
)
discrete_image_token_ids = discrete_image_token_ids.to(
device=input_ids.device
).to(dtype=input_ids.dtype)
vocab_size = self.get_input_embeddings().num_embeddings
if (discrete_image_token_ids >= vocab_size).any() or (
discrete_image_token_ids < 0
).any():
max_id = discrete_image_token_ids.max().item()
min_id = discrete_image_token_ids.min().item()
print(
f"[RANK {rank}] ⚠️ discrete_image_token_ids out of vocab range!"
)
print(f" min={min_id}, max={max_id}, vocab_size={vocab_size}")
print(
f" discrete_image_unit_0_id={self.config.discrete_image_unit_0_id}"
)
print(
f" Expected token range: [{self.config.discrete_image_unit_0_id}, {self.config.discrete_image_unit_0_id + 65535}]"
)
print(f" Clipping to valid vocab range [0, {vocab_size-1}]...")
discrete_image_token_ids = torch.clamp(
discrete_image_token_ids, 0, vocab_size - 1
)
try:
discrete_image_embeddings = self.get_input_embeddings()(
discrete_image_token_ids
)
except RuntimeError as e:
print(
f"[RANK {rank}] 🚨 FATAL: discrete_image embedding lookup failed!"
)
print(f" Error: {e}")
print(
f" discrete_image_token_ids shape: {discrete_image_token_ids.shape}"
)
print(
f" discrete_image_token_ids min/max: {discrete_image_token_ids.min().item()}/{discrete_image_token_ids.max().item()}"
)
print(f" vocab_size: {vocab_size}")
raise
if sum(discrete_image_cnt) == 0:
inputs_embeds[0, 0:0] = discrete_image_embeddings[0, 0:0]
else:
mask = input_ids.eq(self.config.discrete_image_start_id)
positions = mask.nonzero(as_tuple=False)
batch_idx = positions[:, 0]
seq_idx = positions[:, 1]
inputs_embeds[batch_idx, seq_idx, :] = (
discrete_image_embeddings.view(
-1, discrete_image_embeddings.shape[-1]
)
)
if labels is not None:
label_mask = labels.eq(self.config.discrete_image_start_id)
label_positions = label_mask.nonzero(as_tuple=False)
if label_positions.numel() > 0:
input_mask_flat = mask.view(-1)
label_mask_flat = label_mask.view(-1)
trainable_mask_in_input_positions = label_mask_flat[
input_mask_flat
]
discrete_image_token_ids_flat = (
discrete_image_token_ids.view(-1)
)
trainable_token_ids = discrete_image_token_ids_flat[
trainable_mask_in_input_positions
]
label_batch_idx = label_positions[:, 0]
label_seq_idx = label_positions[:, 1]
labels[label_batch_idx, label_seq_idx] = trainable_token_ids
del (
discrete_unit_lists,
discrete_image_token_ids,
discrete_image_embeddings,
)
video_audio_features = None
if (
video_audio_values is not None
and video_audio_masks is not None
and self.audio_model is not None
and self.audio_projector is not None
):
is_dummy_data = isinstance(video_audio_values, torch.Tensor)
if is_dummy_data:
dummy_chunk = video_audio_values[0]
dummy_mask = video_audio_masks[0]
video_audio_values = [[[dummy_chunk]]]
video_audio_masks = [[[dummy_mask]]]
all_audio_chunks = []
all_audio_masks = []
chunks_per_video = []
for sample_video_audios, sample_video_masks in zip(
video_audio_values, video_audio_masks
):
for video_audio_chunks, video_audio_chunk_masks in zip(
sample_video_audios, sample_video_masks
):
if video_audio_chunks and len(video_audio_chunks) > 0:
chunks_per_video.append(len(video_audio_chunks))
all_audio_chunks.extend(video_audio_chunks)
all_audio_masks.extend(video_audio_chunk_masks)
if len(all_audio_chunks) > 0:
audio_values_batch = torch.stack(all_audio_chunks, dim=0).to(
inputs_embeds.device
)
audio_masks_raw = torch.stack(all_audio_masks, dim=0).to(
inputs_embeds.device
)
num_chunks = audio_values_batch.shape[0]
max_mel_seq_len = 3000
audio_feat_lengths, audio_output_lengths = (
self.audio_model._get_feat_extract_output_lengths(
audio_masks_raw.sum(-1)
)
)
if isinstance(audio_feat_lengths, list):
audio_feat_lengths = torch.tensor(
audio_feat_lengths, device=inputs_embeds.device
)
if isinstance(audio_output_lengths, list):
audio_output_lengths = torch.tensor(
audio_output_lengths, device=inputs_embeds.device
)
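                    # (3000 - 2) // 2 + 1 mirrors the encoder's stride-2 conv
                    # front end on the 3000-frame mel input, matching
                    # _get_feat_extract_output_lengths above.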
max_seq_len = (max_mel_seq_len - 2) // 2 + 1
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=audio_feat_lengths.dtype,
device=audio_feat_lengths.device,
)
.unsqueeze(0)
.expand(num_chunks, max_seq_len)
)
lengths_expand = audio_feat_lengths.unsqueeze(1).expand(
num_chunks, max_seq_len
)
padding_mask = seq_range >= lengths_expand
audio_mask_ = padding_mask.view(
num_chunks, 1, 1, max_seq_len
).expand(num_chunks, 1, max_seq_len, max_seq_len)
video_audio_attn_masks = audio_mask_.to(
dtype=self.audio_model.conv1.weight.dtype,
device=self.audio_model.conv1.weight.device,
)
video_audio_attn_masks[audio_mask_] = float("-inf")
with torch.no_grad():
audio_chunk_features = self.audio_model(
audio_values_batch, attention_mask=video_audio_attn_masks
)
if getattr(self.config, "freeze_audio_projector", False):
with torch.no_grad():
audio_chunk_features = self.audio_projector(
audio_chunk_features.last_hidden_state
)
else:
audio_chunk_features = self.audio_projector(
audio_chunk_features.last_hidden_state
)
actual_seq_len = audio_chunk_features.shape[1]
audio_features_mask = torch.arange(
actual_seq_len, device=audio_output_lengths.device
)[None, :]
audio_features_mask = (
audio_features_mask < audio_output_lengths[:, None]
)
if self.video_audio_compressor is not None:
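                        # MambaMia path: concatenate each video's valid
                        # (unpadded) chunk features, zero-pad the videos to a
                        # common length, and let the compressor emit
                        # ceil(seq_len / chunk_size) query tokens per video.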
video_features_list = []
chunk_offset = 0
for num_chunks_in_video in chunks_per_video:
video_valid_features_list = []
for i in range(num_chunks_in_video):
chunk_idx = chunk_offset + i
chunk_mask = audio_features_mask[chunk_idx]
valid_features = audio_chunk_features[chunk_idx][
chunk_mask
]
video_valid_features_list.append(valid_features)
video_all_features = torch.cat(
video_valid_features_list, dim=0
)
video_features_list.append(video_all_features)
chunk_offset += num_chunks_in_video
del video_valid_features_list
mambamia_chunk_size = self.video_audio_compressor.chunk_size
seq_lens = [vf.shape[0] for vf in video_features_list]
num_queries_per_video = [
(sl + mambamia_chunk_size - 1) // mambamia_chunk_size
for sl in seq_lens
]
max_seq_len = max(seq_lens)
hidden_dim = video_features_list[0].shape[1]
num_videos = len(video_features_list)
batched_features = torch.zeros(
num_videos,
max_seq_len,
hidden_dim,
dtype=video_features_list[0].dtype,
device=video_features_list[0].device,
)
for vid_idx, vf in enumerate(video_features_list):
batched_features[vid_idx, : vf.shape[0], :] = vf
del video_features_list
compressed_batch = self.video_audio_compressor(batched_features)
del batched_features
compressed_features_list = []
for vid_idx, nq in enumerate(num_queries_per_video):
valid_compressed = compressed_batch[vid_idx, :nq, :]
compressed_features_list.append(valid_compressed)
video_audio_features = torch.cat(
compressed_features_list, dim=0
)
del compressed_batch, compressed_features_list
else:
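                        # Without the compressor, fall back to mean-pooling the
                        # valid audio features in windows of 25 frames (roughly
                        # 0.5 s per token, assuming a Whisper-style encoder
                        # emitting 50 frames per second).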
pooled_features_list = []
for chunk_idx in range(num_chunks):
chunk_mask = audio_features_mask[chunk_idx]
valid_features = audio_chunk_features[chunk_idx][chunk_mask]
valid_len = valid_features.shape[0]
pool_size = 25
num_pooled = (valid_len + pool_size - 1) // pool_size
for pool_idx in range(num_pooled):
start_idx = pool_idx * pool_size
end_idx = min(start_idx + pool_size, valid_len)
pooled_feat = valid_features[start_idx:end_idx].mean(
dim=0
)
pooled_features_list.append(pooled_feat)
video_audio_features = torch.stack(pooled_features_list, dim=0)
del pooled_features_list
del (
audio_values_batch,
audio_masks_raw,
video_audio_attn_masks,
audio_chunk_features,
)
del (
seq_range,
lengths_expand,
padding_mask,
audio_mask_,
audio_features_mask,
)
if video_audio_features is not None:
video_audio_token_id = self.config.video_audio_start_id
video_audio_cnts = (
(input_ids == video_audio_token_id).sum(dim=1).tolist()
)
if sum(video_audio_cnts) == 0:
inputs_embeds[0, 0:0] = video_audio_features[0:0]
else:
mask = input_ids.eq(video_audio_token_id)
positions = mask.nonzero(as_tuple=False)
batch_idx = positions[:, 0]
seq_idx = positions[:, 1]
num_placeholder_tokens = len(batch_idx)
num_actual_features = video_audio_features.shape[0]
if num_placeholder_tokens != num_actual_features:
rank = int(os.environ.get("RANK", -1))
print(
f"\n[RANK {rank}] ⚠️ VIDEO AUDIO SHAPE MISMATCH DETECTED!"
)
print(
f" Placeholder tokens in input_ids: {num_placeholder_tokens}"
)
print(
f" Actual audio features generated: {num_actual_features}"
)
print(
f" Difference: {num_placeholder_tokens - num_actual_features}"
)
print(f" video_audio_cnts per sample: {video_audio_cnts}")
if video_audio_values is not None and not isinstance(
video_audio_values, torch.Tensor
):
print(f" video_audio_values structure:")
for sample_idx, sample_audios in enumerate(
video_audio_values
):
print(
f" Sample {sample_idx}: {len(sample_audios)} videos"
)
for vid_idx, vid_audios in enumerate(sample_audios):
if vid_audios:
print(
f" Video {vid_idx}: {len(vid_audios)} chunks, first chunk shape: {vid_audios[0].shape if len(vid_audios) > 0 else 'N/A'}"
)
else:
print(
f" Video {vid_idx}: 0 chunks (no audio)"
)
elif isinstance(video_audio_values, torch.Tensor):
print(
f" video_audio_values is dummy tensor (shape: {video_audio_values.shape})"
)
if (
video_audio_masks is not None
and audio_output_lengths is not None
):
print(
f" video_audio_masks analysis (using actual audio_output_lengths):"
)
chunk_idx_global = 0
for sample_idx, sample_masks in enumerate(
video_audio_masks
):
for vid_idx, vid_masks in enumerate(sample_masks):
for chunk_idx_local, mask_tensor in enumerate(
vid_masks
):
mask_sum = int(mask_tensor.sum())
actual_output_len = int(
audio_output_lengths[chunk_idx_global]
)
pool_size = 25
num_pooled = (
actual_output_len + pool_size - 1
) // pool_size
print(
f" Sample {sample_idx}, Video {vid_idx}, Chunk {chunk_idx_local}: mask_sum={mask_sum}, actual_output_len={actual_output_len}, pooled={num_pooled}"
)
chunk_idx_global += 1
print(f" Total chunks processed: {chunk_idx_global}")
raise RuntimeError(
f"Video audio shape mismatch: {num_placeholder_tokens} placeholder tokens "
f"vs {num_actual_features} generated features. See debug info above."
)
inputs_embeds[batch_idx, seq_idx, :] = video_audio_features
del mask, positions, batch_idx, seq_idx
del video_audio_features
if (
audio_values is not None
and len([i for i in audio_values if i is not None]) > 0
):
assert audio_masks is not None
batch_size, _, max_mel_seq_len = audio_values.size()
assert (
max_mel_seq_len == 3000
), f"max_mel_seq_len should be 3000, but got {max_mel_seq_len}"
audio_feat_lengths, audio_output_lengths = (
self.audio_model._get_feat_extract_output_lengths(
audio_masks.sum(-1)
)
)
max_seq_len = (max_mel_seq_len - 2) // 2 + 1
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=audio_feat_lengths.dtype,
device=audio_feat_lengths.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
lengths_expand = audio_feat_lengths.unsqueeze(1).expand(
batch_size, max_seq_len
)
padding_mask = seq_range >= lengths_expand
audio_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
batch_size, 1, max_seq_len, max_seq_len
)
audio_masks = audio_mask_.to(
dtype=self.audio_model.conv1.weight.dtype,
device=self.audio_model.conv1.weight.device,
)
audio_masks[audio_mask_] = float("-inf")
with torch.no_grad():
audio_features = self.audio_model(
audio_values, attention_mask=audio_masks
)
if getattr(self.config, "freeze_audio_projector", False):
with torch.no_grad():
audio_features = self.audio_projector(
audio_features.last_hidden_state
)
else:
audio_features = self.audio_projector(
audio_features.last_hidden_state
)
assert (
self.config.audio_start_id is not None
), "audio_start_id 가 정의되어 있지 않습니다"
audio_cnts = (
(input_ids == self.config.audio_start_id).sum(dim=1).tolist()
)
mask = input_ids.eq(self.config.audio_start_id)
positions = mask.nonzero(as_tuple=False)
batch_idx = positions[:, 0]
seq_idx = positions[:, 1]
if sum(audio_cnts) == 0:
inputs_embeds[0, 0:0] = audio_features[0, 0:0]
else:
num_audios, max_audio_tokens, embed_dim = audio_features.shape
audio_features_mask = torch.arange(
max_audio_tokens, device=audio_output_lengths.device
)[None, :]
audio_features_mask = (
audio_features_mask < audio_output_lengths[:, None]
)
audio_features = audio_features[audio_features_mask]
inputs_embeds[batch_idx, seq_idx, :] = audio_features
del audio_features_mask
del seq_range, lengths_expand, padding_mask, audio_mask_, audio_features
del (
mask,
positions,
batch_idx,
seq_idx,
audio_feat_lengths,
audio_output_lengths,
)
if (
discrete_audio_values is not None
and len([i for i in discrete_audio_values if i is not None]) > 0
):
if (
not hasattr(self, "discrete_audio_model")
or self.discrete_audio_model is None
):
raise ValueError(
"[BUG] discrete_audio_model이 초기화되지 않았지만 discrete_audio_values가 제공되었습니다"
)
assert (
self.config.discrete_audio_start_id is not None
), "discrete_audio_start_id 가 정의되어 있지 않습니다"
discrete_audio_cnts = (
(input_ids == self.config.discrete_audio_start_id)
.sum(dim=1)
.tolist()
)
if sum(discrete_audio_cnts) == 0:
with torch.no_grad():
discrete_unit_list = self.discrete_audio_model.forward(
discrete_audio_values[0]
)
discrete_unit_list = discrete_unit_list.squeeze().detach().cpu()
discrete_audio_token_ids = (
discrete_unit_list + self.config.discrete_audio_unit_0_id
)
discrete_audio_token_ids = discrete_audio_token_ids.to(
device=input_ids.device
).to(dtype=input_ids.dtype)
discrete_audio_embeddings = self.get_input_embeddings()(
discrete_audio_token_ids
)
inputs_embeds[0, 0:0] = discrete_audio_embeddings[0:0]
del (
discrete_unit_list,
discrete_audio_token_ids,
discrete_audio_embeddings,
)
                else:
                    # rank is referenced by the debug prints below.
                    rank = int(os.environ.get("RANK", -1))
                    discrete_audio_value_counter = 0
for b_idx in range(discrete_audio_value_num_per_sample.shape[0]):
if discrete_audio_value_num_per_sample[b_idx] > 0:
discrete_unit_list_parts = []
for _ in range(discrete_audio_value_num_per_sample[b_idx]):
wav = discrete_audio_values[
discrete_audio_value_counter
]
wav = wav[wav != -200.0]
if wav.shape[0] == 0:
print(
f"[RANK {rank}] ❌ wav length is 0 after padding removal! This should be caught by preprocessor."
)
wav = torch.zeros(
400, device=wav.device, dtype=wav.dtype
)
if wav.shape[0] < 400:
print(
f"[RANK {rank}] ⚠️ wav too short ({wav.shape[0]} < 400)! Padding to 400. Preprocessor should catch this."
)
wav = torch.nn.functional.pad(
wav, (0, 400 - wav.shape[0])
)
if torch.isnan(wav).any() or torch.isinf(wav).any():
print(
f"[RANK {rank}] ⚠️ wav contains NaN/Inf! Clamping. Preprocessor should catch this."
)
wav = torch.nan_to_num(
wav, nan=0.0, posinf=1.0, neginf=-1.0
)
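                                # Audio longer than 80 s is tokenized in
                                # chunks; a trailing remainder shorter than the
                                # minimum chunk length is folded into the final
                                # chunk.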
with torch.no_grad():
if wav.shape[0] > 80 * DEFAULT_SAMPLE_RATE:
chunk_size = 80 * DEFAULT_SAMPLE_RATE
min_chunk_size = (
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES
)
discrete_unit_list = None
for start in range(0, wav.shape[0], chunk_size):
end = start + chunk_size
if (
end < wav.shape[0]
and wav.shape[0] - end < min_chunk_size
):
end = wav.shape[0]
chunk_wav = wav[start:end]
if (
chunk_wav.shape[0]
< MIN_DISCRETE_AUDIO_CHUNK_SAMPLES
):
raise RuntimeError(
f"[RANK {rank}] 🚨 CRITICAL BUG: chunk_wav too short "
f"({chunk_wav.shape[0]} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES})!\n"
f" This should NEVER happen - preprocessor asserts MIN_DISCRETE_AUDIO_CHUNK_SAMPLES.\n"
f" start={start}, end={end}, wav.shape[0]={wav.shape[0]}\n"
f" chunk_size={chunk_size}, min_chunk_size={min_chunk_size}\n"
f" Check for: 1) collate_fn padding removal bug, 2) chunking logic mismatch, "
f"3) data corruption during transfer.\n"
f" Cannot continue - skipping would cause shape mismatch downstream."
)
try:
chunk_result = (
self.discrete_audio_model.forward(
chunk_wav
)
)
chunk_result = (
chunk_result.squeeze()
.detach()
.cpu()
)
except RuntimeError as e:
print(
f"[RANK {rank}] ❌ discrete_audio_model.forward() FAILED!"
)
print(f" RuntimeError: {e}")
print(
f" chunk_wav.shape: {chunk_wav.shape}"
)
print(
f" chunk_wav.dtype: {chunk_wav.dtype}"
)
print(
f" chunk_wav.device: {chunk_wav.device}"
)
print(
f" chunk_wav range: [{chunk_wav.min():.4f}, {chunk_wav.max():.4f}]"
)
print(
f" chunk_wav has NaN: {torch.isnan(chunk_wav).any().item()}"
)
print(
f" chunk_wav has Inf: {torch.isinf(chunk_wav).any().item()}"
)
print(
f" start={start}, end={end}, original_wav_len={wav.shape[0]}"
)
raise
if chunk_result.dim() == 0:
print(
f"[RANK {rank}] ⚠️ chunk_result is 0-dim scalar! Skipping this chunk."
)
print(
f" chunk_wav.shape: {chunk_wav.shape}, start={start}, end={end}"
)
del chunk_result
continue
if chunk_result.numel() == 0:
print(
f"[RANK {rank}] ⚠️ chunk_result is empty! Skipping this chunk."
)
print(
f" chunk_wav.shape: {chunk_wav.shape}, start={start}, end={end}"
)
del chunk_result
continue
if discrete_unit_list is None:
discrete_unit_list = chunk_result
else:
discrete_unit_list = torch.cat(
[discrete_unit_list, chunk_result],
dim=-1,
)
del chunk_result
if end >= wav.shape[0]:
break
if discrete_unit_list is None:
print(
f"[RANK {rank}] ⚠️ All chunks failed for this audio! Using dummy token."
)
discrete_unit_list = torch.zeros(
1, dtype=torch.int32
)
else:
try:
discrete_unit_list = (
self.discrete_audio_model.forward(wav)
)
discrete_unit_list = (
discrete_unit_list.squeeze()
.detach()
.cpu()
)
except RuntimeError as e:
print(
f"[RANK {rank}] ❌ discrete_audio_model.forward() FAILED!"
)
print(f" RuntimeError: {e}")
print(f" wav.shape: {wav.shape}")
print(f" wav.dtype: {wav.dtype}")
print(f" wav.device: {wav.device}")
print(
f" wav range: [{wav.min():.4f}, {wav.max():.4f}]"
)
print(
f" wav has NaN: {torch.isnan(wav).any().item()}"
)
print(
f" wav has Inf: {torch.isinf(wav).any().item()}"
)
raise
if discrete_unit_list.dim() == 0:
print(
f"[RANK {rank}] ⚠️ discrete_unit_list is 0-dim scalar! Using dummy value."
)
print(f" wav.shape: {wav.shape}")
discrete_unit_list = (
discrete_unit_list.unsqueeze(0)
)
if discrete_unit_list.numel() == 0:
print(
f"[RANK {rank}] ⚠️ discrete_unit_list is empty! Using dummy token."
)
print(f" wav.shape: {wav.shape}")
discrete_unit_list = torch.zeros(
1, dtype=torch.int32
)
discrete_unit_list_parts.append(discrete_unit_list)
discrete_audio_value_counter += 1
if len(discrete_unit_list_parts) == 0:
print(
f"[RANK {rank}] ⚠️ All discrete audio chunks were invalid/empty for batch {b_idx}!"
)
print(
f" discrete_audio_value_num_per_sample[{b_idx}] = {discrete_audio_value_num_per_sample[b_idx]}"
)
print(
f" Creating dummy discrete_unit_lists to match placeholder count."
)
discrete_audio_start_positions = (
input_ids[b_idx]
== self.config.discrete_audio_start_id
).nonzero(as_tuple=True)[0]
num_placeholders = len(discrete_audio_start_positions)
discrete_unit_lists = torch.zeros(
num_placeholders, dtype=torch.int32
)
print(
f"[RANK {rank}] Created {num_placeholders} dummy tokens (all zeros)."
)
else:
for idx, part in enumerate(discrete_unit_list_parts):
if part.numel() == 0:
print(
f"[RANK {rank}] ⚠️ discrete_unit_list_parts[{idx}] is empty tensor!"
)
if (
torch.isnan(part.float()).any()
or torch.isinf(part.float()).any()
):
print(
f"[RANK {rank}] ⚠️ discrete_unit_list_parts[{idx}] contains NaN/Inf!"
)
part = torch.nan_to_num(
part.float(),
nan=0.0,
posinf=6560.0,
neginf=0.0,
).to(dtype=part.dtype)
discrete_unit_lists = torch.cat(
discrete_unit_list_parts
)
del discrete_unit_list_parts
if (discrete_unit_lists < 0).any() or (
discrete_unit_lists > 6561
).any():
print(
f"[RANK {rank}] ❌ discrete_unit_lists has invalid values!"
)
print(
f" min={discrete_unit_lists.min().item()}, max={discrete_unit_lists.max().item()}"
)
print(
f" Expected range: [0, 6561] (tokenizer has audio0000~audio6561, 6562 tokens)"
)
print(
f" FSQ mathematically max is 6560, but allowing 6561 for tokenizer compatibility"
)
print(f" Clamping to valid range...")
discrete_unit_lists = torch.clamp(
discrete_unit_lists, 0, 6561
)
discrete_audio_token_ids = (
discrete_unit_lists
+ self.config.discrete_audio_unit_0_id
)
discrete_audio_token_ids = discrete_audio_token_ids.to(
device=input_ids.device
).to(dtype=input_ids.dtype)
discrete_audio_start_positions = (
input_ids[b_idx] == self.config.discrete_audio_start_id
).nonzero(as_tuple=True)[0]
assert (
len(discrete_audio_start_positions) > 0
), "discrete_audio_start_id가 input_ids에 존재하지 않습니다"
discrete_audio_embeddings = self.get_input_embeddings()(
discrete_audio_token_ids
)
if len(discrete_audio_start_positions) > 0:
inputs_embeds[
b_idx, discrete_audio_start_positions, :
] = discrete_audio_embeddings
if (
labels is not None
and len(discrete_audio_start_positions) > 0
):
label_mask_audio = labels[b_idx].eq(
self.config.discrete_audio_start_id
)
if label_mask_audio.any():
input_mask_audio = (
input_ids[b_idx]
== self.config.discrete_audio_start_id
)
trainable_mask_in_input_positions_audio = (
label_mask_audio[input_mask_audio]
)
trainable_audio_token_ids = (
discrete_audio_token_ids[
trainable_mask_in_input_positions_audio
]
)
label_positions_audio = label_mask_audio.nonzero(
as_tuple=True
)[0]
labels[b_idx, label_positions_audio] = (
trainable_audio_token_ids
)
return inputs_embeds, labels
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
**kwargs,
):
model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
**kwargs,
)
model.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)
return model
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
*args,
**kwargs,
):
super().register_for_auto_class("AutoModel")
self.config.register_for_auto_class()
super().save_pretrained(save_directory, *args, **kwargs)
@torch.no_grad()
def inference(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[
Union[List[List[torch.FloatTensor]], torch.FloatTensor]
] = None,
image_sizes: Optional[List[List[List[int]]]] = None,
mm_query_lengths: Optional[List[List[int]]] = None,
non_mm_query_lengths: Optional[List[int]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_values: Optional[torch.FloatTensor] = None,
discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None,
audio_masks: Optional[torch.LongTensor] = None,
max_length: int = 196,
min_length: int = 2,
do_sample: bool = True,
num_beams: int = 1,
top_p: float = 0.6,
top_k: int = 0,
temperature: float = 0.5,
repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
use_cache: bool = True,
verbose: bool = False,
**kwargs,
):
"""
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer.
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data.
In cases where a sample contains no images, a single dummy image is included.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param attention_mask: not used
:param max_length: int : The maximum length the generated tokens can have. Corresponds to the length of the input prompt + max_new_tokens.
:param min_length: int : The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + min_new_tokens.
:param num_beams: int : Number of beams for beam search. 1 means no beam search.
:param top_k: int : The number of highest probability vocabulary tokens to keep for top-k-filtering.
:param temperature: float : The value used to modulate the next token probabilities. ( scores / self.temperature )
:param repetition_penalty: float : The parameter for repetition penalty.
:param length_penalty: int : It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence.
:param early_stopping: Union[bool, str] : True, where the generation stops as soon as there are num_beams complete candidates;
False, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates;
"never", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm)
:param use_cache: bool : Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
:param verbose: bool : print debug mention
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
:param num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
:param num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
:param first_last_frames_slows: A List of List that stores the only first and last frames slow mode for each sample in a batch.
:param is_videos: A List of List that stores the boolean value indicating whether each sample in a batch is a video.
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder.
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder.
:param kwargs:
:return:
"""
inputs_embeds, _ = self.extract_inputs_embeds(
input_ids=input_ids,
pixel_values=self.to_vision_model_device(pixel_values),
image_sizes=image_sizes,
mm_query_lengths=mm_query_lengths,
non_mm_query_lengths=non_mm_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
audio_values=audio_values,
discrete_audio_values=discrete_audio_values,
discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample,
audio_masks=audio_masks,
)
if self.without_llm:
inputs_embeds = (
inputs_embeds.to(self.vision_model.device)
if isinstance(inputs_embeds, torch.Tensor)
else inputs_embeds
)
return inputs_embeds
inputs_embeds = (
inputs_embeds.to(self.base_model.device)
if isinstance(inputs_embeds, torch.Tensor)
else inputs_embeds
)
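        # Generation consumes the multimodal inputs_embeds directly; the text
        # config's bos/eos ids are banned via bad_words_ids, and "<|im_end|>"
        # acts as the effective end-of-turn stop token.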
pred = self.language_model.generate(
inputs_embeds=inputs_embeds,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.encode("<|im_end|>")[0],
bad_words_ids=[
[
self.config.text_config.bos_token_id,
],
[
self.config.text_config.eos_token_id,
],
],
max_new_tokens=max_length,
min_length=min_length,
num_beams=num_beams,
do_sample=False if temperature == 0.0 else do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
            early_stopping=num_beams > 1,
use_cache=use_cache,
no_repeat_ngram_size=20,
)
if verbose:
llm_query = self.tokenizer.batch_decode(
[
[
token_id
for token_id in input_ids_row
if token_id != self.tokenizer.pad_token_id
]
for input_ids_row in input_ids.detach().cpu().tolist()
],
skip_special_tokens=False,
)[0]
llm_pred = self.tokenizer.batch_decode(
[
[
token_id
for token_id in pred_row
if token_id != self.tokenizer.pad_token_id
]
for pred_row in pred.detach().cpu().tolist()
],
skip_special_tokens=False,
)[0]
print(f"# [info] llm_query: {llm_query}")
print(f"# [info] llm_pred: {llm_pred}")
return pred
def to_vision_model_device(self, input_tensor):
if isinstance(input_tensor, list):
return [self.to_vision_model_device(item) for item in input_tensor]
elif isinstance(input_tensor, torch.Tensor):
return input_tensor.to(self.vision_model.device)
else:
raise TypeError(
"Unsupported data type. Only tensors and lists are allowed."
)
@classmethod
def from_config(cls, config, vision_model_name_or_path):
return cls(config, vision_model_name_or_path)
def get_language_model(self):
return self.language_model.base_model
def get_vision_model(self):
return self.vision_model
def compute_adaptive_params(
self,
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
image_sizes: Optional[List[List[List[int]]]] = None,
is_videos: Optional[List[List[bool]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
):
assert all(
all(isinstance(value, int) and value >= 0 for value in sublist)
for sublist in num_queries_vis_abstractors
), "All values in num_queries_vis_abstractors must be integers >= 0."
assert all(
all(isinstance(value, int) and value >= 0 for value in sublist)
for sublist in num_queries_vis_abstractors_slow
), "All values in num_queries_vis_abstractors_slow must be integers >= 0."
assert is_videos is not None
is_first_images = []
is_last_images = []
for is_video in is_videos:
for idx, is_video_item in enumerate(is_video):
if idx == 0:
is_first_images.append(True)
else:
is_first_images.append(False)
if idx == len(is_video) - 1:
is_last_images.append(True)
else:
is_last_images.append(False)
num_queries_vis_abstractors = list(chain(*num_queries_vis_abstractors))
num_queries_vis_abstractors_slow = list(
chain(*num_queries_vis_abstractors_slow)
)
image_sizes = list(chain(*image_sizes))
is_videos = list(chain(*is_videos))
first_last_frames_slows = list(chain(*first_last_frames_slows))
use_slowfast = any(
[num_query > 0 for num_query in num_queries_vis_abstractors_slow]
)
num_grids = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)]
num_grids = [0] + num_grids
group_ids = []
if use_slowfast:
new_num_grids = [num_grids[0]]
new_num_queries = []
new_image_sizes = []
new_is_videos = []
for (
num_query,
num_query_slow,
num_grid,
image_size,
is_video,
first_last_frames_slow,
is_first_image,
is_last_image,
) in zip(
num_queries_vis_abstractors,
num_queries_vis_abstractors_slow,
num_grids[1:],
image_sizes,
is_videos,
first_last_frames_slows,
is_first_images,
is_last_images,
):
if not first_last_frames_slow and num_query_slow > 0:
assert is_video is True
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0]
new_num_grids.append(new_num_grids[-1] + 1)
new_num_queries.append(num_query_slow)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
if num_grid >= 2:
new_num_grids.append(new_num_grids[-1] + num_grid - 1)
new_num_queries.append(num_query)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
this_group_ids.append(this_group_ids[-1] + 1)
group_ids.append(this_group_ids)
elif (
first_last_frames_slow
and num_query_slow > 0
and (is_first_image or is_last_image)
):
assert is_video is True
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0]
if num_grid == 1:
new_num_grids.append(new_num_grids[-1] + 1)
new_num_queries.append(num_query_slow)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
if num_grid >= 2:
if is_first_image:
new_num_grids.append(new_num_grids[-1] + 1)
new_num_queries.append(num_query_slow)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
new_num_grids.append(new_num_grids[-1] + num_grid - 1)
new_num_queries.append(num_query)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
this_group_ids.append(this_group_ids[-1] + 1)
elif is_last_image:
new_num_grids.append(new_num_grids[-1] + num_grid - 1)
new_num_queries.append(num_query)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
new_num_grids.append(new_num_grids[-1] + 1)
new_num_queries.append(num_query_slow)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
this_group_ids.append(this_group_ids[-1] + 1)
else:
raise Exception("This case should not be reached.")
group_ids.append(this_group_ids)
else:
new_num_grids.append(new_num_grids[-1] + num_grid)
new_num_queries.append(num_query)
new_image_sizes.append(image_size)
new_is_videos.append(is_video)
start_group_id = group_ids[-1][-1] + 1 if group_ids else 0
group_ids.append([start_group_id])
num_grids = new_num_grids
num_queries_vis_abstractors = new_num_queries
image_sizes = new_image_sizes
is_videos = new_is_videos
else:
num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)]
group_ids = [[group_id] for group_id in range(len(is_videos))]
return num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids
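    # Worked example (slowfast path): one video sample whose pixel_values hold a
    # single tensor of 3 grids, with num_queries_vis_abstractors=[[81]],
    # num_queries_vis_abstractors_slow=[[9]] and first_last_frames_slows=[[False]].
    # The first grid is split off as the slow path (9 queries) and the remaining
    # two grids keep 81 queries each:
    #   num_queries_vis_abstractors -> [9, 81]
    #   num_grids                   -> [0, 1, 3]  (cumulative grid boundaries)
    #   group_ids                   -> [[0, 1]]   (both splits share one source)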
def split_adaptive_params(
self, num_queries_vis_abstractors, num_grids, chunk_size: int, n_chunks: int
):
"""
        Split num_grids/num_queries into chunks of chunk_size, producing at most n_chunks chunks.
        If the actual data runs out early, the remaining chunks are filled with a dummy ([0, 1]).
        Returns
        -------
        chunk_qs : List[List[int]]
        chunk_grids : List[List[int]]
        is_splits : List[bool]
            chunk_qs and chunk_grids have elements of matching lengths, and each
            returned list has exactly n_chunks entries.
"""
total_len = num_grids[-1]
chunk_qs, chunk_grids, is_splits = [], [], []
slices = list(zip(num_grids[:-1], num_grids[1:], num_queries_vis_abstractors))
slice_idx = 0
for chunk_idx in range(n_chunks):
start = chunk_idx * chunk_size
end = start + chunk_size
if start >= total_len:
chunk_grids.append([0, 1])
chunk_qs.append([num_queries_vis_abstractors[-1]])
is_splits.append(False)
continue
grids_in_chunk = [0]
qs_in_chunk = []
while slice_idx < len(slices) and slices[slice_idx][1] <= start:
slice_idx += 1
is_split = False
j = slice_idx
while j < len(slices) and slices[j][0] < end:
s, e, q = slices[j]
left = max(s, start)
right = min(e, end)
off = right - start
if off not in grids_in_chunk:
grids_in_chunk.append(off)
qs_in_chunk.append(q)
if right == end and e != end:
is_split = True
if e > end:
break
j += 1
slice_idx = j
final_off = min(end, total_len) - start
if grids_in_chunk[-1] != final_off:
grids_in_chunk.append(final_off)
qs_in_chunk.append(
qs_in_chunk[-1] if qs_in_chunk else num_queries_vis_abstractors[-1]
)
is_split = True
chunk_grids.append(grids_in_chunk)
chunk_qs.append(qs_in_chunk)
is_splits.append(is_split)
return chunk_qs, chunk_grids, is_splits
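    # Worked example: two images of 2 grids each (num_grids=[0, 2, 4]) with
    # num_queries_vis_abstractors=[81, 81], chunk_size=3, n_chunks=2. The first
    # chunk covers grids [0, 3) and cuts the second image in half; the second
    # chunk picks up the remaining grid:
    #   chunk_qs    -> [[81, 81], [81]]
    #   chunk_grids -> [[0, 2, 3], [0, 1]]
    #   is_splits   -> [True, False]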
class HCXVisionForCausalLM(HCXVisionPreTrainedModel, GenerationMixin):
def __init__(
self,
config: HCXVisionConfig,
without_llm=False,
**kwargs,
):
super().__init__(config, without_llm=without_llm, **kwargs)
text_config = config.get_text_config()
self.model = HCXVisionModel(config=config, **kwargs)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[List[List[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
image_sizes: Optional[List[List[List[int]]]] = None,
vision_query_lengths: Optional[List[List[int]]] = None,
non_vision_query_lengths: Optional[List[List[int]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
        :param input_ids: torch.int64 : torch.Size([batch_size, variable]) : system prompt and question text token indices from the tokenizer.
            In positions where images are input, the value is replaced by config.img_start_id, a vocabulary index used to mark the start of image data.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param past_key_values: None
:param inputs_embeds: None
        :param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1) + num visual tokens)]; all visual token positions are set to IGNORE_INDEX.
:param use_cache: None
        :param output_attentions: Optional[bool] : whether to return the attention weights of each transformer layer (True: included in the output, False: not included).
        :param output_hidden_states: Optional[bool] : whether to return the hidden states of each transformer layer (True: included in the output, False: not included).
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
        :param img_start_ids_list: contains the indices of the img_start_id tokens for each sample.
        :param num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
        :param num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
        :param first_last_frames_slows: A List of List that contains, for each sample in a batch, whether slow mode is applied only to the first and last frames.
        :param is_videos: A List of List that contains the boolean value indicating whether each image in a batch is a video frame.
        :param image_grid_thw: A 3D tensor (torch.int64) for the qwen2.5-vl visual encoder.
        :param pixel_values_videos: A 2D tensor (torch.float32) for the qwen2.5-vl visual encoder.
        :param video_grid_thw: A 3D tensor (torch.int64) for the qwen2.5-vl visual encoder.
:return:
"""
loss = None
logits = None
outputs = self.model.forward(
input_ids=input_ids,
pixel_values=pixel_values,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
image_sizes=image_sizes,
vision_query_lengths=vision_query_lengths,
non_vision_query_lengths=non_vision_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
hidden_states = outputs.last_hidden_state
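        # logits_to_keep == 0 yields slice(0, None), i.e. logits for every
        # position; an int k > 0 keeps only the last k positions (saves memory
        # during generation); a tensor is used directly as an index.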
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.model.language_model.lm_head(
hidden_states[:, slice_indices, :]
) * getattr(self.config.text_config, "logits_scaling", 1)
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.text_config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch.no_grad()
def inference(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[
Union[List[List[torch.FloatTensor]], torch.FloatTensor]
] = None,
image_sizes: Optional[List[List[List[int]]]] = None,
vision_query_lengths: Optional[List[List[int]]] = None,
non_vision_query_lengths: Optional[List[int]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
max_length: int = 196,
min_length: int = 2,
do_sample: bool = True,
num_beams: int = 1,
top_p: float = 0.6,
top_k: int = 0,
temperature: float = 0.5,
repetition_penalty: float = 1.0,
        length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
use_cache: bool = True,
**kwargs,
):
"""
        :param input_ids: torch.int64 : torch.Size([batch_size, variable]) : system prompt and question text token indices from the tokenizer.
            In positions where images are input, the value is replaced by config.img_start_id, a vocabulary index used to mark the start of image data.
:param pixel_values: List of List of 4D tensor (torch.float32)
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images.
:param attention_mask: not used
        :param max_length: int : The maximum number of new tokens to generate (passed to generate() as max_new_tokens).
:param min_length: int : The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + min_new_tokens.
:param num_beams: int : Number of beams for beam search. 1 means no beam search.
:param top_k: int : The number of highest probability vocabulary tokens to keep for top-k-filtering.
:param temperature: float : The value used to modulate the next token probabilities. ( scores / self.temperature )
:param repetition_penalty: float : The parameter for repetition penalty.
        :param length_penalty: float : Exponential penalty applied to the sequence length; the sequence score is divided by the length raised to this power.
        :param early_stopping: Union[bool, str] : True, where the generation stops as soon as there are num_beams complete candidates;
            False, where a heuristic is applied and generation stops when it is very unlikely that better candidates will be found;
"never", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm)
:param use_cache: bool : Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
        :param verbose: bool : whether to print the decoded query and prediction for debugging.
:param image_sizes: Stacked as a List of List, representing image sizes (width, height).
In cases where a sample contains no images, a single dummy image is included.
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input.
In cases where a sample does not contain any images, an empty list is included.
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch.
:param num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid.
:param num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None.
        :param first_last_frames_slows: A List of List that stores, for each sample in a batch, whether slow mode is applied only to the first and last frames.
        :param is_videos: A List of List that stores the boolean value indicating whether each image in a batch is a video frame.
        :param image_grid_thw: A 3D tensor (torch.int64) for the qwen2.5-vl visual encoder.
        :param pixel_values_videos: A 2D tensor (torch.float32) for the qwen2.5-vl visual encoder.
        :param video_grid_thw: A 3D tensor (torch.int64) for the qwen2.5-vl visual encoder.
:param kwargs:
:return:
"""
inputs_embeds, _ = self.model.extract_inputs_embeds(
input_ids=input_ids,
pixel_values=self.to_vision_model_device(pixel_values),
image_sizes=image_sizes,
vision_query_lengths=vision_query_lengths,
non_vision_query_lengths=non_vision_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
if self.without_llm:
inputs_embeds = (
inputs_embeds.to(self.vision_model.device)
if isinstance(inputs_embeds, torch.Tensor)
else inputs_embeds
)
return inputs_embeds
inputs_embeds = (
inputs_embeds.to(self.base_model.device)
if isinstance(inputs_embeds, torch.Tensor)
else inputs_embeds
)
pred = self.language_model.generate(
inputs_embeds=inputs_embeds,
pad_token_id=self.config.text_config.pad_token_id,
eos_token_id=self.config.text_config.eos_token_id,
bad_words_ids=[
[
self.config.text_config.bos_token_id,
],
[
self.config.text_config.eos_token_id,
],
],
max_new_tokens=max_length,
min_length=min_length,
num_beams=num_beams,
do_sample=False if temperature == 0.0 else do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
            early_stopping=num_beams > 1,
use_cache=use_cache,
)
return pred
def to_vision_model_device(self, input_tensor):
if isinstance(input_tensor, list):
return [self.to_vision_model_device(item) for item in input_tensor]
elif isinstance(input_tensor, torch.Tensor):
return input_tensor.to(self.vision_model.device)
else:
raise TypeError(
"Unsupported data type. Only tensors and lists are allowed."
)
def get_input_embeddings(self):
if self.without_llm:
return None
else:
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
if self.without_llm:
return None
else:
return self.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)
def get_decoder(self):
return self.language_model.get_decoder()
def tie_weights(self):
if self.without_llm:
return None
else:
return self.language_model.tie_weights()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
**kwargs,
):
model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
**kwargs,
)
model.tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True
)
return model
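    # E.g. (hypothetical checkpoint path):
    #   model = HCXVisionForCausalLM.from_pretrained("path/to/ckpt", trust_remote_code=True)
    # The matching tokenizer is attached as model.tokenizer, so generated ids
    # can be decoded without loading it separately.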
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
*args,
**kwargs,
):
super().register_for_auto_class("AutoModelForCausalLM")
self.config.register_for_auto_class()
super().save_pretrained(save_directory, *args, **kwargs)
self.config.architectures = ["HCXVisionV2ForCausalLM"]
self.config.auto_map["AutoModelForCausalLM"] = (
"modeling_vlm.HCXVisionForCausalLM"
)
self.config.auto_map["AutoModelForSequenceClassification"] = (
"modeling_vlm.HCXVisionForSequenceClassification"
)
self.config.save_pretrained(save_directory)
@property
def is_qwen_visual(self):
return self.model.is_qwen_visual
@property
def language_model(self):
return self.model.language_model
@property
def vision_model(self):
return self.model.vision_model
@property
def discrete_vision_model(self):
return self.model.discrete_vision_model
@property
def audio_model(self):
return self.model.audio_model
@property
def discrete_audio_model(self):
return self.model.discrete_audio_model
@property
def text_config(self):
return self.model.text_config
@property
def vision_config(self):
return self.model.vision_config
@property
def audio_config(self):
return self.model.audio_config
@property
def mm_projector(self):
return self.model.mm_projector
@property
def audio_projector(self):
return self.model.audio_projector
@property
def anyres(self):
return self.model.anyres
@property
def is_safetensor_save(self):
return self.model.is_safetensor_save
@property
def without_llm(self):
return self.model.without_llm
@property
def image_newline(self):
return self.model.image_newline
class HCXVisionForSequenceClassification(HCXVisionPreTrainedModel):
"""
HCX Vision model for sequence classification tasks.
"""
def __init__(self, config, **kwargs):
super().__init__(config, without_llm=True, **kwargs)
        self.num_labels = getattr(config, "num_labels", 2)
self.model = HCXVisionModel(config=config, **kwargs)
self.score = nn.Linear(
config.text_config.hidden_size, self.num_labels, bias=False
)
self.post_init()
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
image_sizes: Optional[List[List[List[int]]]] = None,
vision_query_lengths: Optional[List[List[int]]] = None,
non_vision_query_lengths: Optional[List[List[int]]] = None,
img_start_ids_list: Optional[List[List[int]]] = None,
num_queries_vis_abstractors: Optional[List[List[int]]] = None,
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
first_last_frames_slows: Optional[List[List[bool]]] = None,
is_videos: Optional[List[List[bool]]] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
) -> SequenceClassifierOutputWithPast:
"""
Forward pass for sequence classification.
"""
transformer_outputs: BaseModelOutputWithPast = self.model(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
image_sizes=image_sizes,
vision_query_lengths=vision_query_lengths,
non_vision_query_lengths=non_vision_query_lengths,
img_start_ids_list=img_start_ids_list,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
first_last_frames_slows=first_last_frames_slows,
is_videos=is_videos,
image_grid_thw=image_grid_thw,
pixel_values_videos=pixel_values_videos,
video_grid_thw=video_grid_thw,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError(
"Cannot handle batch sizes > 1 if no padding token is defined."
)
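        # Pool at the last non-pad token. For right-padded input_ids
        # [5, 6, PAD, PAD]: non_pad_mask = [1, 1, 0, 0], token_indices =
        # [0, 1, 2, 3], so their product [0, 1, 0, 0] has argmax 1, the final
        # real token; for left-padded inputs the argmax lands on the last
        # position, which is also the last real token.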
if self.config.pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
non_pad_mask = (input_ids != self.config.pad_token_id).to(
logits.device, torch.int32
)
token_indices = torch.arange(
input_ids.shape[-1], device=logits.device, dtype=torch.int32
)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
last_non_pad_token = -1
pooled_logits = logits[
torch.arange(batch_size, device=logits.device), last_non_pad_token
]
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
pooled_logits=pooled_logits,
config=self.config,
)
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
*args,
**kwargs,
):
super().register_for_auto_class("AutoModelForSequenceClassification")
self.config.register_for_auto_class()
super().save_pretrained(save_directory, *args, **kwargs)
class VLM_Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
mm_projector_type,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.mm_projector_type = mm_projector_type
if self.mm_projector_type == "mlp":
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
elif self.mm_projector_type == "inverted_mlp":
self.fc1 = nn.Linear(in_features, 2 * hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(2 * hidden_features, out_features)
else:
raise NotImplementedError(
"{} is not implemented".format(self.mm_projector_type)
)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
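# Illustrative sketch (hypothetical dimensions): project 1024-d features into a
# 4096-d LLM embedding space with the plain "mlp" variant.
#   proj = VLM_Mlp("mlp", in_features=1024, hidden_features=4096, out_features=4096)
#   out = proj(torch.randn(2, 10, 1024))  # -> torch.Size([2, 10, 4096])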
class Projector(nn.Module):
"""Base projector class"""
def __init__(
self,
num_queries: int,
num_input_tokens: int,
encoder_hidden_size: int,
hidden_size: int,
output_hidden_size: int,
pos_emb=True,
prenorm=False,
):
super().__init__()
self.num_input_tokens = num_input_tokens
self.output_hidden_size = output_hidden_size
if pos_emb:
self.pos_emb = torch.nn.Parameter(
torch.zeros(1, num_input_tokens, encoder_hidden_size)
)
self.pos_emb.data.normal_(mean=0.0, std=0.02)
else:
self.pos_emb = None
if prenorm:
self.prenorm = LayerNorm(encoder_hidden_size)
else:
self.prenorm = None
self.build_net(
num_queries, encoder_hidden_size, hidden_size, output_hidden_size
)
def build_net(self):
raise NotImplementedError()
def _forward(
self,
x,
num_queries_vis_abstractors: Optional[List[int]] = None,
num_grids: Optional[List[int]] = None,
freeze_before_sampler: bool = False,
):
raise NotImplementedError()
def forward(
self,
x: torch.Tensor,
num_queries_vis_abstractors: Optional[List[int]] = None,
num_grids: Optional[List[int]] = None,
freeze_before_sampler: bool = False,
) -> torch.Tensor:
"""
Args:
x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including cls token.
"""
if self.prenorm is not None:
x = self.prenorm(x)
if self.pos_emb is not None:
x = x + self.pos_emb
x = self._forward(
x,
num_queries_vis_abstractors=num_queries_vis_abstractors,
num_grids=num_grids,
freeze_before_sampler=freeze_before_sampler,
)
return x
class ConvProjector(Projector):
def _forward(
self,
x,
num_queries_vis_abstractors: Optional[List[int]] = None,
num_grids: Optional[List[int]] = None,
freeze_before_sampler: bool = False,
):
hw = int(x.size(1) ** 0.5)
x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
if num_queries_vis_abstractors is not None:
assert num_grids is not None
return self._forward_adaptive_num_query(
x, num_queries_vis_abstractors, num_grids, freeze_before_sampler
)
if freeze_before_sampler:
with torch.no_grad():
x = self.net[0](x)
x = self.net[1](x)
x = self.net[2](x)
else:
x = self.net(x)
x = rearrange(x, "b d h w -> b (h w) d")
x = self.readout(x)
return x
def _forward_adaptive_num_query(
self,
x,
num_queries_vis_abstractors: Optional[List[int]] = None,
num_grids: Optional[List[int]] = None,
freeze_before_sampler: bool = False,
):
assert len(self.net) == 3
if freeze_before_sampler:
with torch.no_grad():
x = self.net[0](x)
else:
x = self.net[0](x)
new_x = []
for i, num_queries in enumerate(num_queries_vis_abstractors):
hw = int(num_queries**0.5)
sampler = nn.AdaptiveAvgPool2d((hw, hw))
out = sampler(x[num_grids[i] : num_grids[i + 1], :])
out = self.net[2](out)
out = rearrange(out, "b d h w -> b (h w) d")
out = self.readout(out)
new_x.append(out)
return new_x
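    # E.g., with num_queries_vis_abstractors=[9, 81] and num_grids=[0, 1, 3]
    # (the slowfast layout produced by compute_adaptive_params), grid 0 is
    # average-pooled to a 3x3 map (9 slow tokens) while grids 1-2 are pooled to
    # 9x9 maps (81 tokens each), letting each split carry a different token
    # budget through the same projector weights.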
class CAbstractor(ConvProjector):
"""C-Abstractor"""
def build_net(
self,
n_queries,
encoder_hidden_size,
hidden_size,
output_hidden_size,
depth=3,
mlp_depth=2,
):
        assert (n_queries**0.5).is_integer(), "n_queries must be a square number"
hw = int(n_queries**0.5)
RegBlock = partial(
RegStage,
stride=1,
dilation=1,
act_layer=nn.SiLU,
norm_layer=LayerNorm2d,
)
s1 = RegBlock(
depth,
encoder_hidden_size,
hidden_size,
)
sampler = nn.AdaptiveAvgPool2d((hw, hw))
s2 = RegBlock(
depth,
hidden_size,
hidden_size,
)
self.net = nn.Sequential(s1, sampler, s2)
self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)
def build_mlp(self, depth, hidden_size, output_hidden_size):
layers = [nn.Linear(hidden_size, output_hidden_size)]
for _ in range(1, depth):
layers.append(nn.SiLU())
layers.append(nn.Linear(output_hidden_size, output_hidden_size))
return nn.Sequential(*layers)
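# Illustrative sketch (hypothetical dimensions): a C-Abstractor that compresses
# a 24x24 grid of visual tokens into 81 queries per grid.
#   projector = CAbstractor(
#       num_queries=81,           # must be a square number
#       num_input_tokens=576,     # 24 x 24 patches from the vision tower
#       encoder_hidden_size=1024,
#       hidden_size=1024,
#       output_hidden_size=4096,  # LLM hidden size
#   )
#   feats = torch.randn(2, 576, 1024)  # (grids, tokens, encoder dim)
#   out = projector(feats)             # -> torch.Size([2, 81, 4096])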
AutoConfig.register("vlm", HCXVisionConfig)
try:
from .configuration_hyperclovax import HyperCLOVAXConfig
from .modeling_hyperclovax import HyperCLOVAXForCausalLM
AutoConfig.register("hyperclovax", HyperCLOVAXConfig)
AutoModelForCausalLM.register(
HyperCLOVAXConfig,
HyperCLOVAXForCausalLM,
)
except Exception:
    # The HyperCLOVAX text backbone is optional; skip registration if its
    # modules are unavailable or registration fails.
    pass