import contextlib
import math
import os
from functools import partial
from itertools import chain
from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
|
|
try:
    from einops import rearrange
    from timm.layers import LayerNorm, LayerNorm2d
    from timm.models.regnet import RegStage
except ImportError:
    print("Packages required for anyres (einops, timm) are not installed.")
|
|
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
)
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import no_init_weights

from .configuration_vlm import HCXVisionConfig
|
|
|
def get_rank():
    if dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
|
|
|
|
def is_ampere_or_newer():
    if not torch.cuda.is_available():
        return False

    gpu_name = torch.cuda.get_device_name()

    ampere_keywords = [
        "RTX 30",
        "RTX 40",
        "A100",
        "H100",
        "A6000",
        "A5000",
        "A4000",
        "A3000",
        "A2000",
        "A1000",
    ]

    return any(keyword in gpu_name for keyword in ampere_keywords)
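
# Note: the keyword match above is a simple heuristic and can miss Ampere-or-newer GPUs whose
# marketing names are not listed (e.g. A40, L40S, H200). A capability-based check is sketched
# below purely as an illustrative alternative; it is not used anywhere in this module.
def _is_ampere_or_newer_by_capability() -> bool:
    """Return True when the current CUDA device reports compute capability >= 8.0 (Ampere+)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8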
|
|
|
|
|
|
|
|
EOT = "<|endofturn|>"
IMG_LOC = "<|IMAGE_PAD|>"
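# Special token strings: EOT marks the end of a conversational turn, and IMG_LOC is the
# image placeholder whose positions receive the projected visual-token embeddings.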
|
|
|
|
|
|
|
|
|
|
|
class HCXVisionPreTrainedModel(PreTrainedModel):
    config_class = HCXVisionConfig
    base_model_prefix = "model"
    vision_model_name = "vision_model"
    _no_split_modules = [
        "CLIPAttention",
        "SiglipVisionModel",
    ]
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True
|
|
|
|
|
def _init_weights(self, module): |
|
|
|
|
|
if ( |
|
|
isinstance(module, nn.Conv2d) |
|
|
or isinstance(module, nn.Embedding) |
|
|
or isinstance(module, nn.Linear) |
|
|
): |
|
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
if hasattr(module, "bias") and module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
|
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
module.bias.data.zero_() |
|
|
module.weight.data.fill_(1.0) |
|
|
elif isinstance(module, nn.Parameter): |
|
|
embed_std = 1 / torch.sqrt(torch.tensor(module.size(0), dtype=torch.float)).to(module.dtype) |
|
|
module.data.normal_(mean=0.0, std=embed_std) |
|
|
|
|
|
|
|
|
class HCXVisionModel(HCXVisionPreTrainedModel): |
|
|
def __init__( |
|
|
self, |
|
|
config: HCXVisionConfig, |
|
|
without_llm=False, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__(config) |
|
|
|
|
|
self.flag_changed_max_position_embeddings = False |
|
|
self.without_llm = without_llm |
|
|
|
|
|
vision_model_type = config.vision_config.model_type |
|
|
|
|
|
self.is_qwen_visual = False |
|
|
if vision_model_type == "qwen2_5_vl_visual": |
|
|
self.is_qwen_visual = True |
|
|
|
|
|
self.freeze_before_sampler = kwargs.pop("freeze_before_sampler", False) |
|
|
|
|
|
vision_config = config.vision_config |
|
|
vision_config.anyres = config.anyres |
|
|
vision_config.max_num_grids = config.max_num_grids |
|
|
vision_config.update({"torch_dtype": config.torch_dtype}) |
|
|
self.vision_config = vision_config |
|
|
if config.anyres: |
|
|
if not getattr(config, "possible_resolutions", []): |
|
|
possible_resolutions = [] |
|
|
if config.anyres: |
|
|
assert config.max_num_grids > 0 |
|
|
for i in range(1, config.max_num_grids + 1): |
|
|
for j in range(1, config.max_num_grids + 1): |
|
|
if i == 1 and j == 1 and not config.use_1x1_grid: |
|
|
continue |
|
|
if i * j <= config.max_num_grids: |
|
|
possible_resolutions.append([i, j]) |
|
|
|
|
|
possible_resolutions = [ |
|
|
[ys * vision_config.image_size, xs * vision_config.image_size] |
|
|
for ys, xs in possible_resolutions |
|
|
] |
|
|
self.config.possible_resolutions = possible_resolutions |
|
|
else: |
|
|
self.config.possible_resolutions = config.possible_resolutions |
|
|
|
|
|
if without_llm: |
|
|
|
|
|
|
|
|
vision_config.vison_pretrained_name_or_path = config.vision_model_name_or_path |
|
|
with no_init_weights(): |
|
|
if self.is_qwen_visual and is_ampere_or_newer(): |
|
|
vision_config._attn_implementation = "flash_attention_2" |
|
|
self.vision_model = AutoModel.from_config( |
|
|
vision_config, trust_remote_code=True |
|
|
) |
|
|
self.vision_model.gradient_checkpointing_enable() |
|
|
if config.mm_projector_type == "qwen_merger": |
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
def new_forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): |
|
|
The final hidden states of the model. |
|
|
grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): |
|
|
The temporal, height and width of feature shape of each image in LLM. |
|
|
|
|
|
Returns: |
|
|
`torch.Tensor`: hidden_states. |
|
|
""" |
|
|
hidden_states = self.patch_embed(hidden_states) |
|
|
rotary_pos_emb = self.rot_pos_emb(grid_thw) |
|
|
window_index, cu_window_seqlens = self.get_window_index(grid_thw) |
|
|
cu_window_seqlens = torch.tensor( |
|
|
cu_window_seqlens, |
|
|
device=hidden_states.device, |
|
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, |
|
|
) |
|
|
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) |
|
|
|
|
|
seq_len, _ = hidden_states.size() |
|
|
hidden_states = hidden_states.reshape( |
|
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 |
|
|
) |
|
|
hidden_states = hidden_states[window_index, :, :] |
|
|
hidden_states = hidden_states.reshape(seq_len, -1) |
|
|
rotary_pos_emb = rotary_pos_emb.reshape( |
|
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 |
|
|
) |
|
|
rotary_pos_emb = rotary_pos_emb[window_index, :, :] |
|
|
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) |
|
|
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
|
|
position_embeddings = (emb.cos(), emb.sin()) |
|
|
|
|
|
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( |
|
|
dim=0, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, |
|
|
) |
|
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
|
|
|
|
|
for layer_num, blk in enumerate(self.blocks): |
|
|
if layer_num in self.fullatt_block_indexes: |
|
|
cu_seqlens_now = cu_seqlens |
|
|
else: |
|
|
cu_seqlens_now = cu_window_seqlens |
|
|
if self.gradient_checkpointing and self.training: |
|
|
hidden_states = self._gradient_checkpointing_func( |
|
|
blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings |
|
|
) |
|
|
else: |
|
|
hidden_states = blk( |
|
|
hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return hidden_states, window_index |
|
|
|
|
|
import types |
|
|
|
|
|
self.vision_model.forward = types.MethodType(new_forward, self.vision_model) |
|
|
self.vision_model.merger = nn.Identity() |
|
|
|
|
|
if hasattr(config, "text_config") and config.text_config is not None: |
|
|
text_config = config.text_config |
|
|
else: |
|
|
raise ValueError("text_config is not defined") |
|
|
text_config.update({"torch_dtype": config.torch_dtype}) |
|
|
if config.text_config.model_type in ["llama", "hyperclovax", "gpt2"]: |
|
|
text_config._attn_implementation = config._attn_implementation |
|
|
if text_config.model_type != "hyperclovax": |
|
|
text_config.logits_scaling = 1.0 |
|
|
|
|
|
text_config.vocab_size = ( |
|
|
text_config.padded_vocab_size if hasattr(text_config, "padded_vocab_size") else text_config.vocab_size |
|
|
) |
|
|
|
|
|
if not without_llm: |
|
|
with no_init_weights(): |
|
|
self.language_model = AutoModelForCausalLM.from_config(text_config, trust_remote_code=True) |
|
|
|
|
|
if config.text_config.model_type in ["llama", "hyperclovax", "gpt2"]: |
|
|
self.language_model.gradient_checkpointing_enable() |
|
|
self.num_queries_vis_abstractor = config.num_queries_vis_abstractor |
|
|
|
|
|
|
|
|
input_hidden_size = vision_config.hidden_size |
|
|
if vision_config.model_type == "qwen2_5_vl_visual": |
|
|
input_hidden_size = vision_config.out_hidden_size |
|
|
if config.mm_projector_type == "linear": |
|
|
self.mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size) |
|
|
|
|
|
elif config.mm_projector_type == "cabstractor": |
|
|
self.mm_projector = CAbstractor( |
|
|
num_queries=self.num_queries_vis_abstractor, |
|
|
num_input_tokens=(self.vision_config.image_size // self.vision_config.patch_size) ** 2, |
|
|
encoder_hidden_size=input_hidden_size, |
|
|
hidden_size=input_hidden_size, |
|
|
output_hidden_size=text_config.hidden_size, |
|
|
pos_emb=config.proj_pos_emb, |
|
|
prenorm=config.proj_prenorm, |
|
|
) |
|
|
self.mm_projector.pos_emb.to(config.torch_dtype) |
|
|
elif config.mm_projector_type == "qwen_merger": |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( |
|
|
Qwen2_5_VLPatchMerger, |
|
|
) |
|
|
|
|
|
self.mm_projector = Qwen2_5_VLPatchMerger(dim=text_config.hidden_size, context_dim=input_hidden_size) |
|
|
|
|
|
def new_forward(self, inputs) -> torch.Tensor: |
|
|
x, window_index = inputs |
|
|
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) |
|
|
reverse_indices = torch.argsort(window_index) |
|
|
x = x[reverse_indices, :] |
|
|
return x |
|
|
|
|
|
            import types  # defensive local import; `types` is otherwise only imported on the vision-model path above

            self.mm_projector.forward = types.MethodType(new_forward, self.mm_projector)
|
|
|
|
|
else: |
|
|
self.mm_projector = VLM_Mlp( |
|
|
config.mm_projector_type, |
|
|
input_hidden_size, |
|
|
hidden_features=input_hidden_size, |
|
|
out_features=text_config.hidden_size, |
|
|
) |
|
|
self.use_nth_layer = config.use_nth_layer |
|
|
self.model_parallel = False |
|
|
self.device_map = None |
|
|
self.vision_model_use_no_grad = None |
|
|
|
|
|
self.text_config = text_config |
|
|
|
|
|
self.anyres = config.anyres |
|
|
self.unpad = config.unpad |
|
|
self.vision_input_chunk_size = kwargs.pop("vision_input_chunk_size", None) |
|
|
if self.anyres: |
|
|
self.image_newline = nn.Parameter(torch.empty(text_config.hidden_size, dtype=self.dtype)) |
|
|
|
|
|
self.is_safetensor_save = kwargs.get("is_safetensor_save", True) |
|
|
self._backward_compatibility_gradient_checkpointing() |
|
|
self.mm_projector.to(config.torch_dtype) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
non_vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
|
first_last_frames_slows: Optional[List[List[bool]]] = None, |
|
|
is_videos: Optional[List[List[bool]]] = None, |
|
|
image_grid_thw: Optional[torch.LongTensor] = None, |
|
|
pixel_values_videos: Optional[torch.FloatTensor] = None, |
|
|
video_grid_thw: Optional[torch.LongTensor] = None, |
|
|
**kwargs, |
|
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
|
""" |
|
|
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer. |
|
|
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data. |
|
|
:param pixel_values: List of List of 4D tensor (torch.float32) |
|
|
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images. |
|
|
:param past_key_values: None |
|
|
:param inputs_embeds: None |
|
|
:param use_cache: None |
|
|
:param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) |
|
|
:param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) |
|
|
:param image_sizes: Stacked as a List of List, representing image sizes (width, height). |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input. |
|
|
In cases where a sample does not contain any images, an empty list is included. |
|
|
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch. |
|
|
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample. |
|
|
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid. |
|
|
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None. |
|
|
:first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch. |
|
|
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video. |
|
|
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder. |
|
|
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:return: |
|
|
""" |
|
|
output_attentions = ( |
|
|
output_attentions if output_attentions is not None else self.config.vision_config.output_attentions |
|
|
) |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states |
|
|
) |
|
|
|
|
|
if inputs_embeds is None and past_key_values is None: |
|
|
inputs_embeds = self.extract_inputs_embeds( |
|
|
input_ids=input_ids, |
|
|
pixel_values=pixel_values, |
|
|
past_key_values=past_key_values, |
|
|
image_sizes=image_sizes, |
|
|
vision_query_lengths=vision_query_lengths, |
|
|
non_vision_query_lengths=non_vision_query_lengths, |
|
|
img_start_ids_list=img_start_ids_list, |
|
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
|
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, |
|
|
first_last_frames_slows=first_last_frames_slows, |
|
|
is_videos=is_videos, |
|
|
image_grid_thw=image_grid_thw, |
|
|
pixel_values_videos=pixel_values_videos, |
|
|
video_grid_thw=video_grid_thw, |
|
|
) |
|
|
|
|
|
if inputs_embeds is not None: |
|
|
input_ids = None |
|
|
|
|
|
outputs = self.language_model.base_model( |
|
|
input_ids=input_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
return outputs |
|
|
|
|
|
def determine_non_vision_query_lengths(self, input_ids, pad_id, img_start_id): |
|
|
"""non_vision_query_lengths 를 계산하는 함수 |
|
|
input_ids 가 collate 될때, 오른쪽에 pad_id 가 채워지기 때문에 이 값을 찾는 방식을 통해 계산됨 |
|
|
또한 img_start_id 는 visual token 이 들어서는 자리이기 때문에, 해당 indices 은 제거 |
|
|
""" |
|
|
non_vision_query_lengths = [] |
|
|
batch_size, len_seq = input_ids.size(0), input_ids.size(1) |
|
|
|
|
|
for i in range(batch_size): |
|
|
temp_idx = (input_ids[i] == pad_id).nonzero() |
|
|
eos_idx = temp_idx[0, 0].item() if len(temp_idx) > 0 else len_seq |
|
|
num_imgs = (input_ids[i] == img_start_id).sum().item() |
|
|
non_vision_query_lengths.append(eos_idx - num_imgs) |
|
|
|
|
|
if all([pad_id in input_id for input_id in input_ids.tolist()]): |
|
|
non_vision_query_lengths = [ |
|
|
non_vision_query_length + 1 for non_vision_query_length in non_vision_query_lengths |
|
|
] |
|
|
|
|
|
return non_vision_query_lengths |
|
|
|
|
|
def determine_vision_query_lengths(self, image_features, image_cnts): |
|
|
"""vision_query_lengths 를 계산하는 함수 |
|
|
image_features tensor 의 shape 을 통해 계산된다. |
|
|
이미지가 1장도 없는 sample 의 경우 dummy image 1장이 들어가기 때문에, 따로 빈 list 처리 또한 추가 |
|
|
""" |
|
|
vision_query_lengths = [ |
|
|
[image_feature.size(0) for image_feature in image_feature_list] for image_feature_list in image_features |
|
|
] |
|
|
|
|
|
for i, image_cnt in enumerate(image_cnts): |
|
|
if image_cnt == 0: |
|
|
assert len(vision_query_lengths[i]) == 1 |
|
|
vision_query_lengths[i] = [] |
|
|
|
|
|
return vision_query_lengths |
|
|
|
|
|
|
|
|
def get_input_embeddings(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.get_input_embeddings() |
|
|
|
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.language_model.set_input_embeddings(value) |
|
|
|
|
|
|
|
|
def get_output_embeddings(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.get_output_embeddings() |
|
|
|
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.language_model.set_output_embeddings(new_embeddings) |
|
|
|
|
|
|
|
|
def set_decoder(self, decoder): |
|
|
self.language_model.set_decoder(decoder) |
|
|
|
|
|
|
|
|
def get_decoder(self): |
|
|
return self.language_model.get_decoder() |
|
|
|
|
|
|
|
|
def tie_weights(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.tie_weights() |
|
|
|
|
|
|
|
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: |
|
|
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) |
|
|
|
|
|
self.config.text_config.vocab_size = model_embeds.num_embeddings |
|
|
self.vocab_size = model_embeds.num_embeddings |
|
|
return model_embeds |
|
|
|
|
|
def extract_inputs_embeds( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
non_vision_query_lengths: Optional[List[int]] = None, |
|
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
|
first_last_frames_slows: Optional[List[List[bool]]] = None, |
|
|
is_videos: Optional[List[List[bool]]] = None, |
|
|
image_grid_thw: Optional[torch.LongTensor] = None, |
|
|
pixel_values_videos: Optional[torch.FloatTensor] = None, |
|
|
video_grid_thw: Optional[torch.LongTensor] = None, |
|
|
): |
|
|
""" |
|
|
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer. |
|
|
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data. |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param pixel_values: List of List of 4D tensor (torch.float32) |
|
|
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images. |
|
|
:param past_key_values: None : (batch_size, num_heads, sequence_length - 1, embed_size_per_head): Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up |
|
|
:param image_sizes: Stacked as a List of List, representing image sizes (width, height). |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input. |
|
|
In cases where a sample does not contain any images, an empty list is included. |
|
|
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch. |
|
|
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample. |
|
|
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid. |
|
|
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None. |
|
|
:first_last_frames_slows: A List of bool that contains the information of whether the slowfast algorithm is applied to the first or last frames of the video. |
|
|
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video. |
|
|
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder. |
|
|
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:return: |
|
|
""" |
|
|
inputs_embeds = None |
|
|
if past_key_values: |
|
|
pass |
|
|
else: |
|
|
if self.is_qwen_visual: |
|
|
inputs_embeds = self.get_input_embeddings()(input_ids) |
|
|
context_vision_model = torch.no_grad() if self.config.freeze_encoder else contextlib.nullcontext() |
|
|
|
|
|
if pixel_values is not None: |
|
|
with context_vision_model: |
|
|
image_features = self.vision_model(pixel_values, grid_thw=image_grid_thw) |
|
|
image_features = self.mm_projector(image_features) |
|
|
|
|
|
if img_start_ids_list is None: |
|
|
image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist() |
|
|
else: |
|
|
image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list] |
|
|
|
|
|
mask = input_ids.eq(self.config.img_start_id) |
|
|
positions = mask.nonzero(as_tuple=False) |
|
|
|
|
|
batch_idx = positions[:, 0] |
|
|
seq_idx = positions[:, 1] |
|
|
|
|
|
if sum(image_cnts) == 0: |
|
|
image_features = image_features[0:0] |
|
|
inputs_embeds[batch_idx, seq_idx, :] = image_features.to(device=inputs_embeds.device) |
|
|
|
|
|
if pixel_values_videos is not None: |
|
|
with context_vision_model: |
|
|
video_features = self.vision_model(pixel_values_videos, grid_thw=video_grid_thw) |
|
|
video_features = self.mm_projector(video_features) |
|
|
|
|
|
video_cnts = (input_ids == self.config.video_start_id).sum(dim=1).tolist() |
|
|
mask = input_ids.eq(self.config.video_start_id) |
|
|
positions = mask.nonzero(as_tuple=False) |
|
|
|
|
|
batch_idx = positions[:, 0] |
|
|
seq_idx = positions[:, 1] |
|
|
|
|
|
if sum(video_cnts) == 0: |
|
|
video_features = video_features[0:0] |
|
|
inputs_embeds[batch_idx, seq_idx, :] = video_features.to(device=inputs_embeds.device) |
|
|
else: |
|
|
|
|
|
len_pixel_values = [len(pixel_value) for pixel_value in pixel_values] |
|
|
concat_pixel_values = torch.cat(list(chain(*pixel_values)), dim=0) |
|
|
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_adaptive_anyres = num_queries_vis_abstractors is not None and any( |
|
|
self.num_queries_vis_abstractor != num_queries_vis_abstractor |
|
|
for sublist in num_queries_vis_abstractors |
|
|
for num_queries_vis_abstractor in sublist |
|
|
) |
|
|
if not is_adaptive_anyres: |
|
|
image_sizes = list(chain(*image_sizes)) |
|
|
if is_videos is not None: |
|
|
is_videos = list(chain(*is_videos)) |
|
|
else: |
|
|
is_videos = [False] * len(image_sizes) |
|
|
|
|
|
group_ids = None |
|
|
else: |
|
|
|
|
|
|
|
|
is_cabstractor = False |
|
|
for submodule in self.mm_projector.modules(): |
|
|
if isinstance(submodule, CAbstractor): |
|
|
is_cabstractor = True |
|
|
break |
|
|
assert is_cabstractor |
|
|
|
|
|
assert num_queries_vis_abstractors_slow is not None |
|
|
|
|
|
num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids = ( |
|
|
self.compute_adaptive_params( |
|
|
pixel_values, |
|
|
num_queries_vis_abstractors, |
|
|
num_queries_vis_abstractors_slow, |
|
|
image_sizes, |
|
|
is_videos, |
|
|
first_last_frames_slows, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
if torch.is_grad_enabled(): |
|
|
if self.vision_model_use_no_grad is None: |
|
|
self.vision_model_use_no_grad = all( |
|
|
not p.requires_grad for p in self.vision_model.vision_model.encoder.parameters() |
|
|
) |
|
|
context_vision_model = torch.no_grad() if self.vision_model_use_no_grad else contextlib.nullcontext() |
|
|
if self.vision_input_chunk_size is not None: |
|
|
|
|
|
chunk_size = self.vision_input_chunk_size |
|
|
|
|
|
local_batch_size = torch.tensor([concat_pixel_values.size(0)], device=concat_pixel_values.device) |
|
|
gathered_batch_sizes = [ |
|
|
torch.zeros_like(local_batch_size) for _ in range(torch.distributed.get_world_size()) |
|
|
] |
|
|
torch.distributed.all_gather(gathered_batch_sizes, local_batch_size) |
|
|
gathered_batch_sizes = torch.stack(gathered_batch_sizes) |
|
|
max_batch_size = gathered_batch_sizes.max().item() |
|
|
|
|
|
n_chunks = math.ceil(max_batch_size / chunk_size) |
|
|
|
|
|
if is_adaptive_anyres: |
|
|
chunk_num_queries_vis_abstractors, chunk_num_grids, chunk_is_splits = ( |
|
|
self.split_adaptive_params( |
|
|
num_queries_vis_abstractors, |
|
|
num_grids, |
|
|
chunk_size, |
|
|
n_chunks, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
dummy_shape = (1,) + tuple(concat_pixel_values.shape[1:]) |
|
|
dummy = torch.zeros( |
|
|
dummy_shape, dtype=concat_pixel_values.dtype, device=concat_pixel_values.device |
|
|
).to(self.vision_model.dtype) |
|
|
|
|
|
else: |
|
|
|
|
|
chunk_size = concat_pixel_values.size(0) |
|
|
n_chunks = 1 |
|
|
|
|
|
image_forward_outs = [] |
|
|
|
|
|
for i in range(n_chunks): |
|
|
start = i * chunk_size |
|
|
end = (i + 1) * chunk_size |
|
|
|
|
|
chunk = concat_pixel_values[start:end].to(self.vision_model.dtype) |
|
|
current_chunk_size = chunk.size(0) |
|
|
|
|
|
|
|
|
if current_chunk_size == 0: |
|
|
chunk = dummy |
|
|
|
|
|
|
|
|
if self.use_nth_layer == -1: |
|
|
|
|
|
self.vision_model.vision_model.post_layernorm = nn.Identity() |
|
|
with context_vision_model: |
|
|
outs = self.vision_model(chunk) |
|
|
outs = outs.last_hidden_state[:, visual_token_idx:] |
|
|
else: |
|
|
with context_vision_model: |
|
|
outs = self.vision_model(chunk, output_hidden_states=True) |
|
|
outs = outs.hidden_states[self.use_nth_layer][:, visual_token_idx:] |
|
|
if self.vision_model_use_no_grad: |
|
|
outs = outs.detach().requires_grad_(True) |
|
|
if not is_adaptive_anyres: |
|
|
if self.freeze_before_sampler and self.training: |
|
|
outs = self.mm_projector(outs, freeze_before_sampler=True) |
|
|
else: |
|
|
outs = self.mm_projector(outs) |
|
|
if current_chunk_size > 0: |
|
|
image_forward_outs.append(outs) |
|
|
else: |
|
|
if n_chunks != 1: |
|
|
current_num_queries_vis_abstractors = chunk_num_queries_vis_abstractors[i] |
|
|
current_num_grids = chunk_num_grids[i] |
|
|
else: |
|
|
current_num_queries_vis_abstractors = num_queries_vis_abstractors |
|
|
current_num_grids = num_grids |
|
|
if self.freeze_before_sampler and self.training: |
|
|
outs = self.mm_projector( |
|
|
outs, |
|
|
num_queries_vis_abstractors=current_num_queries_vis_abstractors, |
|
|
num_grids=current_num_grids, |
|
|
freeze_before_sampler=True, |
|
|
) |
|
|
else: |
|
|
outs = self.mm_projector( |
|
|
outs, |
|
|
num_queries_vis_abstractors=current_num_queries_vis_abstractors, |
|
|
num_grids=current_num_grids, |
|
|
) |
|
|
if current_chunk_size > 0: |
|
|
if i > 0 and chunk_is_splits[i - 1]: |
|
|
|
|
|
image_forward_outs[-1] = torch.cat([image_forward_outs[-1], outs[0]], dim=0) |
|
|
image_forward_outs.extend(outs[1:]) |
|
|
else: |
|
|
image_forward_outs.extend(outs) |
|
|
|
|
|
if not is_adaptive_anyres: |
|
|
|
|
|
|
|
|
image_forward_outs = torch.cat(image_forward_outs, dim=0).to(image_forward_outs[0].dtype) |
|
|
|
|
|
if img_start_ids_list is None: |
|
|
image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist() |
|
|
else: |
|
|
image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list] |
|
|
|
|
|
if self.anyres: |
|
|
split_sizes = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
if not is_adaptive_anyres: |
|
|
image_features = [image_forward_out for image_forward_out in image_forward_outs] |
|
|
else: |
|
|
image_features = [image_forward_out.unsqueeze(0) for image_forward_out in image_forward_outs] |
|
|
|
|
|
image_features = [ |
|
|
image_features[sum(len_pixel_values[:i]) : sum(len_pixel_values[: i + 1])] |
|
|
for i in range(len(len_pixel_values)) |
|
|
] |
|
|
|
|
|
|
|
|
if self.without_llm: |
|
|
return image_features |
|
|
|
|
|
batch_size = input_ids.size(0) |
|
|
image_feature_dim = image_features[0][0].size(1) |
|
|
image_feature_dtype = image_features[0][0].dtype |
|
|
|
|
|
if img_start_ids_list is None: |
|
|
image_cnts = (input_ids == self.config.img_start_id).sum(dim=1).tolist() |
|
|
else: |
|
|
image_cnts = [len(img_start_ids) for img_start_ids in img_start_ids_list] |
|
|
|
|
|
if non_vision_query_lengths is None: |
|
|
non_vision_query_lengths = self.determine_non_vision_query_lengths( |
|
|
input_ids, self.config.text_config.pad_token_id, self.config.img_start_id |
|
|
) |
|
|
|
|
|
if vision_query_lengths is None: |
|
|
vision_query_lengths = self.determine_vision_query_lengths(image_features, image_cnts) |
|
|
|
|
|
|
|
|
len_inputs_embeds = max( |
|
|
[ |
|
|
sum(vision_query_length) + non_vision_query_length |
|
|
for non_vision_query_length, vision_query_length in zip( |
|
|
non_vision_query_lengths, vision_query_lengths |
|
|
) |
|
|
] |
|
|
) |
|
|
|
|
|
inputs_embeds = torch.zeros( |
|
|
[batch_size, len_inputs_embeds, image_feature_dim], |
|
|
dtype=image_feature_dtype, |
|
|
device=self.device, |
|
|
requires_grad=True, |
|
|
).clone() |
|
|
|
|
|
|
|
|
temp_embeds = self.get_input_embeddings()(input_ids) |
|
|
|
|
|
|
|
|
for batch_idx, sample in enumerate(input_ids): |
|
|
|
|
|
non_vision_query_length = non_vision_query_lengths[batch_idx] |
|
|
|
|
|
sample = sample[: non_vision_query_length + image_cnts[batch_idx]] |
|
|
|
|
|
if image_cnts[batch_idx] == 0: |
|
|
temp_idx = 0 |
|
|
|
|
|
|
|
|
inputs_embeds[batch_idx, :non_vision_query_length] = temp_embeds[batch_idx][ |
|
|
:non_vision_query_length |
|
|
] |
|
|
inputs_embeds[batch_idx, temp_idx:temp_idx] = image_features[batch_idx][0][ |
|
|
0:0 |
|
|
] |
|
|
else: |
|
|
if img_start_ids_list is None: |
|
|
img_start_ids = (sample == self.config.img_start_id).nonzero() |
|
|
else: |
|
|
img_start_ids = img_start_ids_list[batch_idx] |
|
|
assert len(img_start_ids) == image_cnts[batch_idx] == len(image_features[batch_idx]) |
|
|
|
|
|
input_start, temp_start = 0, 0 |
|
|
|
|
|
|
|
|
for multi_img_idx, img_start_idx in enumerate(img_start_ids): |
|
|
|
|
|
token_len = img_start_idx - temp_start |
|
|
|
|
|
|
|
|
inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[ |
|
|
batch_idx, temp_start : temp_start + token_len |
|
|
] |
|
|
|
|
|
|
|
|
inputs_embeds[ |
|
|
batch_idx, |
|
|
input_start |
|
|
+ token_len : input_start |
|
|
+ token_len |
|
|
+ vision_query_lengths[batch_idx][multi_img_idx], |
|
|
] = image_features[batch_idx][multi_img_idx] |
|
|
|
|
|
|
|
|
input_start += token_len + vision_query_lengths[batch_idx][multi_img_idx] |
|
|
temp_start += token_len + 1 |
|
|
|
|
|
|
|
|
token_len = min(sample[temp_start:].size(0), inputs_embeds.size(1) - input_start) |
|
|
inputs_embeds[batch_idx, input_start : input_start + token_len] = temp_embeds[ |
|
|
batch_idx, temp_start : temp_start + token_len |
|
|
] |
|
|
return inputs_embeds |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained( |
|
|
cls, |
|
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], |
|
|
*model_args, |
|
|
**kwargs, |
|
|
): |
|
|
model = super().from_pretrained( |
|
|
pretrained_model_name_or_path, |
|
|
*model_args, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
|
return model |
|
|
|
|
|
def save_pretrained( |
|
|
self, |
|
|
save_directory: Union[str, os.PathLike], |
|
|
*args, |
|
|
**kwargs, |
|
|
): |
|
|
super().register_for_auto_class("AutoModel") |
|
|
self.config.register_for_auto_class() |
|
|
super().save_pretrained(save_directory, *args, **kwargs) |
|
|
|
|
|
def compute_adaptive_params( |
|
|
self, |
|
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
|
is_videos: Optional[List[List[bool]]] = None, |
|
|
first_last_frames_slows: Optional[List[List[bool]]] = None, |
|
|
): |
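        # Flattens the per-sample lists and, when the slowfast path is active, splits each video's
        # grids into a slow entry (first grid) and a fast entry (remaining grids), tracking which
        # flattened entries belong to the same original image/video via group_ids. num_grids is
        # returned as cumulative offsets into the concatenated grid tensor.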
|
|
|
|
|
assert all( |
|
|
all(isinstance(value, int) and value >= 0 for value in sublist) for sublist in num_queries_vis_abstractors |
|
|
), "All values in num_queries_vis_abstractors must be integers >= 0." |
|
|
|
|
|
assert all( |
|
|
all(isinstance(value, int) and value >= 0 for value in sublist) |
|
|
for sublist in num_queries_vis_abstractors_slow |
|
|
), "All values in num_queries_vis_abstractors_slow must be integers >= 0." |
|
|
|
|
|
assert is_videos is not None |
|
|
|
|
|
|
|
|
is_first_images = [] |
|
|
is_last_images = [] |
|
|
for is_video in is_videos: |
|
|
for idx, is_video_item in enumerate(is_video): |
|
|
if idx == 0: |
|
|
is_first_images.append(True) |
|
|
else: |
|
|
is_first_images.append(False) |
|
|
if idx == len(is_video) - 1: |
|
|
is_last_images.append(True) |
|
|
else: |
|
|
is_last_images.append(False) |
|
|
|
|
|
num_queries_vis_abstractors = list(chain(*num_queries_vis_abstractors)) |
|
|
num_queries_vis_abstractors_slow = list(chain(*num_queries_vis_abstractors_slow)) |
|
|
image_sizes = list(chain(*image_sizes)) |
|
|
is_videos = list(chain(*is_videos)) |
|
|
first_last_frames_slows = list(chain(*first_last_frames_slows)) |
|
|
|
|
|
|
|
|
use_slowfast = any([num_query > 0 for num_query in num_queries_vis_abstractors_slow]) |
|
|
|
|
|
num_grids = [pixel_value.shape[0] for pixel_value in chain(*pixel_values)] |
|
|
num_grids = [0] + num_grids |
|
|
group_ids = [] |
|
|
|
|
|
if use_slowfast: |
|
|
new_num_grids = [num_grids[0]] |
|
|
new_num_queries = [] |
|
|
new_image_sizes = [] |
|
|
new_is_videos = [] |
|
|
|
|
|
|
|
|
|
|
|
for ( |
|
|
num_query, |
|
|
num_query_slow, |
|
|
num_grid, |
|
|
image_size, |
|
|
is_video, |
|
|
first_last_frames_slow, |
|
|
is_first_image, |
|
|
is_last_image, |
|
|
) in zip( |
|
|
num_queries_vis_abstractors, |
|
|
num_queries_vis_abstractors_slow, |
|
|
num_grids[1:], |
|
|
image_sizes, |
|
|
is_videos, |
|
|
first_last_frames_slows, |
|
|
is_first_images, |
|
|
is_last_images, |
|
|
): |
|
|
|
|
|
if not first_last_frames_slow and num_query_slow > 0: |
|
|
assert is_video is True |
|
|
|
|
|
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] |
|
|
|
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
|
new_num_queries.append(num_query_slow) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
|
|
|
if num_grid >= 2: |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
|
new_num_queries.append(num_query) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
|
|
|
|
group_ids.append(this_group_ids) |
|
|
elif ( |
|
|
first_last_frames_slow and num_query_slow > 0 and (is_first_image or is_last_image) |
|
|
): |
|
|
|
|
|
assert is_video is True |
|
|
|
|
|
this_group_ids = [group_ids[-1][-1] + 1 if group_ids else 0] |
|
|
|
|
|
if num_grid == 1: |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
|
new_num_queries.append(num_query_slow) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
|
|
|
if num_grid >= 2: |
|
|
if is_first_image: |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
|
new_num_queries.append(num_query_slow) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
|
new_num_queries.append(num_query) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
|
elif is_last_image: |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid - 1) |
|
|
new_num_queries.append(num_query) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + 1) |
|
|
new_num_queries.append(num_query_slow) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
this_group_ids.append(this_group_ids[-1] + 1) |
|
|
else: |
|
|
raise Exception("This case should not be reached.") |
|
|
group_ids.append(this_group_ids) |
|
|
|
|
|
else: |
|
|
|
|
|
new_num_grids.append(new_num_grids[-1] + num_grid) |
|
|
new_num_queries.append(num_query) |
|
|
new_image_sizes.append(image_size) |
|
|
new_is_videos.append(is_video) |
|
|
|
|
|
start_group_id = group_ids[-1][-1] + 1 if group_ids else 0 |
|
|
group_ids.append([start_group_id]) |
|
|
|
|
|
num_grids = new_num_grids |
|
|
num_queries_vis_abstractors = new_num_queries |
|
|
image_sizes = new_image_sizes |
|
|
is_videos = new_is_videos |
|
|
else: |
|
|
num_grids = [sum(num_grids[:i]) for i in range(1, len(num_grids) + 1)] |
|
|
group_ids = [[group_id] for group_id in range(len(is_videos))] |
|
|
|
|
|
return num_queries_vis_abstractors, num_grids, image_sizes, is_videos, group_ids |
|
|
|
|
|
def split_adaptive_params( |
|
|
self, num_queries_vis_abstractors, num_grids, chunk_size: int, n_chunks: int |
|
|
): |
|
|
""" |
|
|
num_grids/num_queries 를 chunk_size 단위로 최대 n_chunks 만큼 자른다. |
|
|
실제 데이터가 부족하면 남은 chunk 는 더미([0,1]) 로 채운다. |
|
|
|
|
|
Returns |
|
|
------- |
|
|
chunk_qs : List[List[int]] |
|
|
chunk_grids: List[List[int]] |
|
|
각 원소 길이는 동일하며, 전체 길이는 정확히 n_chunks. |
|
|
""" |
|
|
total_len = num_grids[-1] |
|
|
chunk_qs, chunk_grids, is_splits = [], [], [] |
|
|
|
|
|
|
|
|
|
|
|
slices = list(zip(num_grids[:-1], num_grids[1:], num_queries_vis_abstractors)) |
|
|
slice_idx = 0 |
|
|
|
|
|
for chunk_idx in range(n_chunks): |
|
|
start = chunk_idx * chunk_size |
|
|
end = start + chunk_size |
|
|
|
|
|
|
|
|
if start >= total_len: |
|
|
chunk_grids.append([0, 1]) |
|
|
chunk_qs.append([num_queries_vis_abstractors[-1]]) |
|
|
is_splits.append(False) |
|
|
continue |
|
|
|
|
|
grids_in_chunk = [0] |
|
|
qs_in_chunk = [] |
|
|
|
|
|
|
|
|
while slice_idx < len(slices) and slices[slice_idx][1] <= start: |
|
|
slice_idx += 1 |
|
|
|
|
|
is_split = False |
|
|
j = slice_idx |
|
|
while j < len(slices) and slices[j][0] < end: |
|
|
s, e, q = slices[j] |
|
|
|
|
|
|
|
|
left = max(s, start) |
|
|
right = min(e, end) |
|
|
off = right - start |
|
|
|
|
|
if off not in grids_in_chunk: |
|
|
grids_in_chunk.append(off) |
|
|
qs_in_chunk.append(q) |
|
|
if right == end and e != end: |
|
|
is_split = True |
|
|
|
|
|
|
|
|
if e > end: |
|
|
break |
|
|
j += 1 |
|
|
slice_idx = j |
|
|
|
|
|
|
|
|
final_off = min(end, total_len) - start |
|
|
if grids_in_chunk[-1] != final_off: |
|
|
grids_in_chunk.append(final_off) |
|
|
qs_in_chunk.append(qs_in_chunk[-1] if qs_in_chunk else num_queries_vis_abstractors[-1]) |
|
|
|
|
|
is_split = True |
|
|
|
|
|
chunk_grids.append(grids_in_chunk) |
|
|
chunk_qs.append(qs_in_chunk) |
|
|
is_splits.append(is_split) |
|
|
|
|
|
return chunk_qs, chunk_grids, is_splits |
|
|
|
|
|
|
|
|
class HCXVisionForCausalLM(HCXVisionPreTrainedModel, GenerationMixin): |
|
|
def __init__( |
|
|
self, |
|
|
config: HCXVisionConfig, |
|
|
without_llm=False, |
|
|
**kwargs, |
|
|
): |
|
|
super().__init__(config, without_llm=without_llm, **kwargs) |
|
|
text_config = config.get_text_config() |
|
|
self.model = HCXVisionModel(config=config, **kwargs) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
pixel_values: Optional[List[List[torch.FloatTensor]]] = None, |
|
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = True, |
|
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
non_vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
|
first_last_frames_slows: Optional[List[List[bool]]] = None, |
|
|
is_videos: Optional[List[List[bool]]] = None, |
|
|
image_grid_thw: Optional[torch.LongTensor] = None, |
|
|
pixel_values_videos: Optional[torch.FloatTensor] = None, |
|
|
video_grid_thw: Optional[torch.LongTensor] = None, |
|
|
logits_to_keep: Union[int, torch.Tensor] = 0, |
|
|
**kwargs, |
|
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
|
""" |
|
|
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer. |
|
|
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data. |
|
|
:param pixel_values: List of List of 4D tensor (torch.float32) |
|
|
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images. |
|
|
:param past_key_values: None |
|
|
:param inputs_embeds: None |
|
|
:param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1)+ num visual tokens)] visual token 들은 모두 IGNORE_INDEX |
|
|
:param use_cache: None |
|
|
:param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) |
|
|
:param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) |
|
|
:param image_sizes: Stacked as a List of List, representing image sizes (width, height). |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input. |
|
|
In cases where a sample does not contain any images, an empty list is included. |
|
|
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch. |
|
|
:img_start_ids_list: contains the indices of the img_start_id tokens for each sample. |
|
|
:num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid. |
|
|
:num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None. |
|
|
:first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch. |
|
|
:is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video. |
|
|
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder. |
|
|
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:return: |
|
|
""" |
|
|
loss = None |
|
|
logits = None |
|
|
outputs = self.model.forward( |
|
|
input_ids=input_ids, |
|
|
pixel_values=pixel_values, |
|
|
past_key_values=past_key_values, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
image_sizes=image_sizes, |
|
|
vision_query_lengths=vision_query_lengths, |
|
|
non_vision_query_lengths=non_vision_query_lengths, |
|
|
img_start_ids_list=img_start_ids_list, |
|
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
|
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, |
|
|
first_last_frames_slows=first_last_frames_slows, |
|
|
is_videos=is_videos, |
|
|
image_grid_thw=image_grid_thw, |
|
|
pixel_values_videos=pixel_values_videos, |
|
|
video_grid_thw=video_grid_thw, |
|
|
) |
|
|
hidden_states = outputs.last_hidden_state |
|
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
|
|
logits = self.model.language_model.lm_head(hidden_states[:, slice_indices, :]) * getattr( |
|
|
self.config.text_config, "logits_scaling", 1 |
|
|
) |
|
|
|
|
|
loss = None |
|
|
if labels is not None: |
|
|
loss = self.loss_function( |
|
|
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs |
|
|
) |
|
|
return CausalLMOutputWithPast( |
|
|
loss=loss, |
|
|
logits=logits, |
|
|
past_key_values=outputs.past_key_values, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
@torch.no_grad() |
|
|
def inference( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
pixel_values: Optional[ |
|
|
Union[List[List[torch.FloatTensor]], torch.FloatTensor] |
|
|
] = None, |
|
|
image_sizes: Optional[List[List[List[int]]]] = None, |
|
|
vision_query_lengths: Optional[List[List[int]]] = None, |
|
|
non_vision_query_lengths: Optional[List[int]] = None, |
|
|
num_queries_vis_abstractors: Optional[List[List[int]]] = None, |
|
|
num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, |
|
|
first_last_frames_slows: Optional[List[List[bool]]] = None, |
|
|
is_videos: Optional[List[List[bool]]] = None, |
|
|
img_start_ids_list: Optional[List[List[int]]] = None, |
|
|
image_grid_thw: Optional[torch.LongTensor] = None, |
|
|
pixel_values_videos: Optional[torch.FloatTensor] = None, |
|
|
video_grid_thw: Optional[torch.LongTensor] = None, |
|
|
max_length: int = 196, |
|
|
min_length: int = 2, |
|
|
do_sample: bool = True, |
|
|
num_beams: int = 1, |
|
|
top_p: float = 0.6, |
|
|
top_k: int = 0, |
|
|
temperature: float = 0.5, |
|
|
repetition_penalty: float = 1.0, |
|
|
        length_penalty: float = 1.0,
|
|
early_stopping: Union[bool, str] = False, |
|
|
use_cache: bool = True, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
:param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer. |
|
|
In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data. |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param pixel_values: List of List of 4D tensor (torch.float32) |
|
|
Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images. |
|
|
:param attention_mask: not used |
|
|
:param max_length: int : The maximum length the generated tokens can have. Corresponds to the length of the input prompt + max_new_tokens. |
|
|
:param min_length: int : The minimum length of the sequence to be generated. Corresponds to the length of the input prompt + min_new_tokens. |
|
|
:param num_beams: int : Number of beams for beam search. 1 means no beam search. |
|
|
:param top_k: int : The number of highest probability vocabulary tokens to keep for top-k-filtering. |
|
|
:param temperature: float : The value used to modulate the next token probabilities. ( scores / self.temperature ) |
|
|
:param repetition_penalty: float : The parameter for repetition penalty. |
|
|
:param length_penalty: int : It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. |
|
|
:param early_stopping: Union[bool, str] : True, where the generation stops as soon as there are num_beams complete candidates; |
|
|
False, where an heuristic is applied and the generation stops when is it very unlikely to find better candidates; |
|
|
"never", where the beam search procedure only stops when there cannot be better candidates (canonical beam search algorithm) |
|
|
:param use_cache: bool : Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. |
|
|
:param verbose: bool : print debug mention |
|
|
:param image_sizes: Stacked as a List of List, representing image sizes (width, height). |
|
|
In cases where a sample contains no images, a single dummy image is included. |
|
|
:param vision_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input. |
|
|
In cases where a sample does not contain any images, an empty list is included. |
|
|
:param non_vision_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch. |
|
|
:param num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid. |
|
|
:param num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None. |
|
|
:param first_last_frames_slows: A List of List that stores the only first and last frames slow mode for each sample in a batch. |
|
|
:param is_videos: A List of List that stores the boolean value indicating whether each sample in a batch is a video. |
|
|
:image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder. |
|
|
:video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. |
|
|
:param kwargs: |
|
|
:return: |
|
|
""" |
|
|
|
|
|
|
|
|
inputs_embeds = self.model.extract_inputs_embeds( |
|
|
input_ids=input_ids, |
|
|
pixel_values=self.to_vision_model_device(pixel_values), |
|
|
image_sizes=image_sizes, |
|
|
vision_query_lengths=vision_query_lengths, |
|
|
non_vision_query_lengths=non_vision_query_lengths, |
|
|
img_start_ids_list=img_start_ids_list, |
|
|
num_queries_vis_abstractors=num_queries_vis_abstractors, |
|
|
num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, |
|
|
first_last_frames_slows=first_last_frames_slows, |
|
|
is_videos=is_videos, |
|
|
image_grid_thw=image_grid_thw, |
|
|
pixel_values_videos=pixel_values_videos, |
|
|
video_grid_thw=video_grid_thw, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.without_llm: |
|
|
inputs_embeds = ( |
|
|
inputs_embeds.to(self.vision_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds |
|
|
) |
|
|
return inputs_embeds |
|
|
|
|
|
inputs_embeds = ( |
|
|
inputs_embeds.to(self.base_model.device) if isinstance(inputs_embeds, torch.Tensor) else inputs_embeds |
|
|
) |
|
|
|
|
|
|
|
|
pred = self.language_model.generate( |
|
|
inputs_embeds=inputs_embeds, |
|
|
pad_token_id=self.config.text_config.pad_token_id, |
|
|
eos_token_id=self.config.text_config.eos_token_id, |
|
|
bad_words_ids=[ |
|
|
[ |
|
|
self.config.text_config.bos_token_id, |
|
|
], |
|
|
[ |
|
|
self.config.text_config.eos_token_id, |
|
|
], |
|
|
], |
|
|
max_new_tokens=max_length, |
|
|
min_length=min_length, |
|
|
num_beams=num_beams, |
|
|
do_sample=False if temperature == 0.0 else do_sample, |
|
|
top_k=top_k, |
|
|
top_p=top_p, |
|
|
temperature=temperature, |
|
|
repetition_penalty=repetition_penalty, |
|
|
length_penalty=length_penalty, |
|
|
early_stopping=False if num_beams <= 1 else True, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
return pred |
|
|
|
|
|
def to_vision_model_device(self, input_tensor): |
|
|
if isinstance(input_tensor, list): |
|
|
return [self.to_vision_model_device(item) for item in input_tensor] |
|
|
elif isinstance(input_tensor, torch.Tensor): |
|
|
return input_tensor.to(self.vision_model.device) |
|
|
else: |
|
|
raise TypeError( |
|
|
"Unsupported data type. Only tensors and lists are allowed." |
|
|
) |
|
|
|
|
|
|
|
|
def get_input_embeddings(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.get_input_embeddings() |
|
|
|
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.language_model.set_input_embeddings(value) |
|
|
|
|
|
|
|
|
def get_output_embeddings(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.get_output_embeddings() |
|
|
|
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.language_model.set_output_embeddings(new_embeddings) |
|
|
|
|
|
|
|
|
def set_decoder(self, decoder): |
|
|
self.language_model.set_decoder(decoder) |
|
|
|
|
|
|
|
|
def get_decoder(self): |
|
|
return self.language_model.get_decoder() |
|
|
|
|
|
|
|
|
def tie_weights(self): |
|
|
if self.without_llm: |
|
|
return None |
|
|
else: |
|
|
return self.language_model.tie_weights() |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained( |
|
|
cls, |
|
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], |
|
|
*model_args, |
|
|
**kwargs, |
|
|
): |
|
|
model = super().from_pretrained( |
|
|
pretrained_model_name_or_path, |
|
|
*model_args, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
|
return model |
|
|
|
|
|
def save_pretrained( |
|
|
self, |
|
|
save_directory: Union[str, os.PathLike], |
|
|
*args, |
|
|
**kwargs, |
|
|
): |
|
|
super().register_for_auto_class("AutoModelForCausalLM") |
|
|
self.config.register_for_auto_class() |
|
|
super().save_pretrained(save_directory, *args, **kwargs) |
|
|
self.config.architectures = ["HCXVisionV2ForCausalLM"] |
|
|
self.config.auto_map["AutoModelForCausalLM"] = "modeling_vlm.HCXVisionForCausalLM" |
|
|
self.config.auto_map["AutoModelForSequenceClassification"] = "modeling_vlm.HCXVisionForSequenceClassification" |
|
|
self.config.save_pretrained(save_directory) |
|
|
|
|
|
|
|
|
@property |
|
|
def is_qwen_visual(self): |
|
|
return self.model.is_qwen_visual |
|
|
|
|
|
@property |
|
|
def language_model(self): |
|
|
return self.model.language_model |
|
|
|
|
|
@property |
|
|
def vision_model(self): |
|
|
return self.model.vision_model |
|
|
|
|
|
@property |
|
|
def text_config(self): |
|
|
return self.model.text_config |
|
|
|
|
|
@property |
|
|
def vision_config(self): |
|
|
return self.model.vision_config |
|
|
|
|
|
@property |
|
|
def mm_projector(self): |
|
|
return self.model.mm_projector |
|
|
|
|
|
@property |
|
|
def anyres(self): |
|
|
return self.model.anyres |
|
|
|
|
|
@property |
|
|
def is_safetensor_save(self): |
|
|
return self.model.is_safetensor_save |
|
|
|
|
|
@property |
|
|
def without_llm(self): |
|
|
return self.model.without_llm |
|
|
|
|
|
@property |
|
|
def image_newline(self): |
|
|
return self.model.image_newline |
|
|
|
|
|
|
|
|
class HCXVisionForSequenceClassification(HCXVisionPreTrainedModel):
    """HCX Vision model with a sequence classification head on top.

    A single linear layer (`score`) is applied to the hidden state of the last
    non-padding token of each sequence, following the usual convention for
    decoder-only classification models.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, without_llm=True, **kwargs)
        self.num_labels = config.num_labels if hasattr(config, "num_labels") else 2
        self.model = HCXVisionModel(config=config, **kwargs)
        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        image_sizes: Optional[List[List[List[int]]]] = None,
        vision_query_lengths: Optional[List[List[int]]] = None,
        non_vision_query_lengths: Optional[List[List[int]]] = None,
        img_start_ids_list: Optional[List[List[int]]] = None,
        num_queries_vis_abstractors: Optional[List[List[int]]] = None,
        num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
        first_last_frames_slows: Optional[List[List[bool]]] = None,
        is_videos: Optional[List[List[bool]]] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ) -> SequenceClassifierOutputWithPast:
        """Forward pass for sequence classification.

        Logits are pooled from the last non-padding token of each sequence
        (or from the final position when no pad token is defined).
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            image_sizes=image_sizes,
            vision_query_lengths=vision_query_lengths,
            non_vision_query_lengths=non_vision_query_lengths,
            img_start_ids_list=img_start_ids_list,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
            first_last_frames_slows=first_last_frames_slows,
            is_videos=is_videos,
            image_grid_thw=image_grid_thw,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )
        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Index of the last non-padding token in each sequence.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        *args,
        **kwargs,
    ):
        super().register_for_auto_class("AutoModelForSequenceClassification")
        self.config.register_for_auto_class()
        super().save_pretrained(save_directory, *args, **kwargs)
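
# Hedged inference sketch for the classifier above. The checkpoint path and the already
# prepared `input_ids` / `attention_mask` / `pixel_values` are placeholders: the exact
# input format depends on the accompanying preprocessing code, which is not in this file.
#
#   model = HCXVisionForSequenceClassification.from_pretrained(
#       "path/to/hcx-vision-classifier", num_labels=2, trust_remote_code=True
#   )
#   outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
#   predicted_class = outputs.logits.argmax(-1)  # pooled logits: [batch_size, num_labels]
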
class HCXVisionForTokenClassification(HCXVisionPreTrainedModel):
    """HCX Vision model for token classification tasks (e.g., per-token value prediction for a PPO critic).

    Returns logits for every token instead of a pooled output.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, without_llm=True, **kwargs)
        self.num_labels = config.num_labels if hasattr(config, "num_labels") else 1
        self.model = HCXVisionModel(config=config, **kwargs)

        # Pick the classifier dropout from the config when available, falling back to 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config.text_config, "hidden_dropout", None) is not None:
            classifier_dropout = config.text_config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)

        self.score = nn.Linear(config.text_config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        image_sizes: Optional[List[List[List[int]]]] = None,
        vision_query_lengths: Optional[List[List[int]]] = None,
        non_vision_query_lengths: Optional[List[List[int]]] = None,
        img_start_ids_list: Optional[List[List[int]]] = None,
        num_queries_vis_abstractors: Optional[List[List[int]]] = None,
        num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None,
        first_last_frames_slows: Optional[List[List[bool]]] = None,
        is_videos: Optional[List[List[bool]]] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ) -> TokenClassifierOutput:
        """Forward pass for token classification.

        Returns:
            TokenClassifierOutput with logits of shape [batch_size, sequence_length, num_labels].
            `labels` is accepted for interface compatibility, but no loss is computed here;
            loss computation is left to the caller (e.g., the RL trainer driving the critic).
        """
        transformer_outputs: BaseModelOutputWithPast = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            image_sizes=image_sizes,
            vision_query_lengths=vision_query_lengths,
            non_vision_query_lengths=non_vision_query_lengths,
            img_start_ids_list=img_start_ids_list,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow,
            first_last_frames_slows=first_last_frames_slows,
            is_videos=is_videos,
            image_grid_thw=image_grid_thw,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

        hidden_states = transformer_outputs.last_hidden_state
        logits = self.score(hidden_states)

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        *args,
        **kwargs,
    ):
        super().register_for_auto_class("AutoModelForTokenClassification")
        self.config.register_for_auto_class()
        super().save_pretrained(save_directory, *args, **kwargs)
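
# Hedged sketch of using the token-classification head as a PPO value head. `values`
# below are the per-token scalar predictions when num_labels == 1; the masking step and
# the checkpoint path are illustrative, not part of this file.
#
#   critic = HCXVisionForTokenClassification.from_pretrained(
#       "path/to/hcx-vision-checkpoint", num_labels=1, trust_remote_code=True
#   )
#   out = critic(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
#   values = out.logits.squeeze(-1)    # [batch_size, sequence_length]
#   values = values * attention_mask   # mask padding before computing returns/advantages
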
class VLM_Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        mm_projector_type,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mm_projector_type = mm_projector_type
        if self.mm_projector_type == "mlp":
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
        elif self.mm_projector_type == "inverted_mlp":
            # Same two-layer MLP, but with a hidden layer twice as wide.
            self.fc1 = nn.Linear(in_features, 2 * hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(2 * hidden_features, out_features)
        else:
            raise NotImplementedError(f"{self.mm_projector_type} is not implemented")

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
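
# Hedged shape sketch for the projector MLP above; the dimensions are illustrative,
# not values taken from a specific config.
#
#   proj = VLM_Mlp("inverted_mlp", in_features=1152, hidden_features=1152, out_features=4096)
#   tokens = torch.randn(2, 729, 1152)   # [batch, visual tokens, encoder hidden size]
#   print(proj(tokens).shape)            # torch.Size([2, 729, 4096])
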
class Projector(nn.Module):
    """Base class for visual projectors that map encoder features to LLM input embeddings."""

    def __init__(
        self,
        num_queries: int,
        num_input_tokens: int,
        encoder_hidden_size: int,
        hidden_size: int,
        output_hidden_size: int,
        pos_emb=True,
        prenorm=False,
    ):
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Optional learnable positional embedding over the input visual tokens.
        if pos_emb:
            self.pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, encoder_hidden_size))
            self.pos_emb.data.normal_(mean=0.0, std=0.02)
        else:
            self.pos_emb = None

        if prenorm:
            self.prenorm = LayerNorm(encoder_hidden_size)
        else:
            self.prenorm = None

        self.build_net(num_queries, encoder_hidden_size, hidden_size, output_hidden_size)

    def build_net(self, n_queries, encoder_hidden_size, hidden_size, output_hidden_size):
        raise NotImplementedError()

    def _forward(
        self,
        x,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
        freeze_before_sampler: bool = False,
    ):
        raise NotImplementedError()

    def forward(
        self,
        x: torch.Tensor,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
        freeze_before_sampler: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including the cls token.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            x = x + self.pos_emb

        x = self._forward(
            x,
            num_queries_vis_abstractors=num_queries_vis_abstractors,
            num_grids=num_grids,
            freeze_before_sampler=freeze_before_sampler,
        )

        return x


class ConvProjector(Projector):
    def _forward(
        self,
        x,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
        freeze_before_sampler: bool = False,
    ):
        # (B, L, D) -> (B, D, H, W), assuming a square grid of visual tokens.
        hw = int(x.size(1) ** 0.5)
        x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)

        if num_queries_vis_abstractors is not None:
            assert num_grids is not None
            return self._forward_adaptive_num_query(x, num_queries_vis_abstractors, num_grids, freeze_before_sampler)

        if freeze_before_sampler:
            # Keep the stage before the pooling sampler frozen.
            with torch.no_grad():
                x = self.net[0](x)
            x = self.net[1](x)
            x = self.net[2](x)
        else:
            x = self.net(x)
        x = rearrange(x, "b d h w -> b (h w) d")
        x = self.readout(x)

        return x

    def _forward_adaptive_num_query(
        self,
        x,
        num_queries_vis_abstractors: Optional[List[int]] = None,
        num_grids: Optional[List[int]] = None,
        freeze_before_sampler: bool = False,
    ):
        # Re-run the pooling with a per-group number of output queries. `num_grids` holds
        # cumulative grid offsets, so group i covers x[num_grids[i] : num_grids[i + 1]].
        assert len(self.net) == 3  # s1, sampler, s2

        if freeze_before_sampler:
            with torch.no_grad():
                x = self.net[0](x)
        else:
            x = self.net[0](x)

        new_x = []
        for i, num_queries in enumerate(num_queries_vis_abstractors):
            hw = int(num_queries**0.5)
            sampler = nn.AdaptiveAvgPool2d((hw, hw))
            out = sampler(x[num_grids[i] : num_grids[i + 1], :])
            out = self.net[2](out)

            out = rearrange(out, "b d h w -> b (h w) d")
            out = self.readout(out)

            new_x.append(out)

        return new_x


class CAbstractor(ConvProjector):
    """C-Abstractor: convolutional stages around an adaptive average-pooling sampler."""

    def build_net(self, n_queries, encoder_hidden_size, hidden_size, output_hidden_size, depth=3, mlp_depth=2):
        assert (n_queries**0.5).is_integer(), "n_queries must be square number"
        hw = int(n_queries**0.5)

        RegBlock = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        s1 = RegBlock(
            depth,
            encoder_hidden_size,
            hidden_size,
        )
        sampler = nn.AdaptiveAvgPool2d((hw, hw))
        s2 = RegBlock(
            depth,
            hidden_size,
            hidden_size,
        )

        self.net = nn.Sequential(s1, sampler, s2)
        self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)

    def build_mlp(self, depth, hidden_size, output_hidden_size):
        layers = [nn.Linear(hidden_size, output_hidden_size)]
        for _ in range(1, depth):
            layers.append(nn.SiLU())
            layers.append(nn.Linear(output_hidden_size, output_hidden_size))
        return nn.Sequential(*layers)
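
# Hedged construction sketch for the C-Abstractor. The sizes below are illustrative; the
# real values come from the vision/text configs, and RegStage requires timm (see the
# guarded imports at the top of the file).
#
#   projector = CAbstractor(
#       num_queries=144,            # must be a square number (12 x 12 output grid)
#       num_input_tokens=729,       # 27 x 27 visual tokens from the encoder
#       encoder_hidden_size=1152,
#       hidden_size=1024,
#       output_hidden_size=4096,
#       pos_emb=True,
#       prenorm=False,
#   )
#   feats = torch.randn(2, 729, 1152)
#   print(projector(feats).shape)   # torch.Size([2, 144, 4096])
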
AutoConfig.register("vlm", HCXVisionConfig) |
|
|
try: |
|
|
from .configuration_hyperclovax import HyperCLOVAXConfig |
|
|
from .modeling_hyperclovax import HyperCLOVAXForCausalLM |
|
|
|
|
|
AutoConfig.register("hyperclovax", HyperCLOVAXConfig) |
|
|
AutoModelForCausalLM.register( |
|
|
HyperCLOVAXConfig, |
|
|
HyperCLOVAXForCausalLM, |
|
|
) |
|
|
except: |
|
|
pass |
|
|
|