from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import Qwen2_5_VLConfig, AutoConfig, AutoModelForCausalLM
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
from vlm_fo1.model.multimodal_encoder.qwen2_5_vl_encoder import Qwen2_5_VlVisionTower
from vlm_fo1.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_REGION_INDEX, QWEN2_5_VL_IMAGE_TOKEN, QWEN2_5_VL_IMAGE_TOKEN_INDEX
from ..omchat_arch import OmChatMetaModel, OmChatMetaForCausalLM
# Custom config extending Qwen2_5_VLConfig for the OmChat multimodal model
class OmChatQwen25VLConfig(Qwen2_5_VLConfig):
model_type = "omchat_qwen2_5_vl"
rotary_type = "normal_rotary"
multi_scale_im = None
vision_tower_aux = None
# Core model definition: inherits from OmChat and Qwen multimodal base
class OmChatQwen25VLModel(OmChatMetaModel, Qwen2_5_VLModel):
config_class = OmChatQwen25VLConfig
def __init__(self, config: Qwen2_5_VLConfig):
super(OmChatQwen25VLModel, self).__init__(config)
# Main class for multimodal CausalLM
class OmChatQwen25VLForCausalLM(Qwen2_5_VLForConditionalGeneration, OmChatMetaForCausalLM):
config_class = OmChatQwen25VLConfig
def __init__(self, config, delay_load=True):
# Ensure config has delay_load property
if not hasattr(config, 'delay_load'):
config.delay_load = delay_load
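# Bypass Qwen2_5_VLForConditionalGeneration.__init__ via the explicit super() target below, so its stock vision/language submodules are not constructed; the OmChat model built next replaces them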
super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)
self.model = OmChatQwen25VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.rope_deltas = None # cache rope_deltas here
self.post_init()
# Encode input images into feature representations
def encode_images(self, images, images_grid_thw=None):
# If vision_tower is Qwen2.5-specific, use its custom forward signature
if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
image_features = self.get_model().get_vision_tower()(images, images_grid_thw)
image_features, image_grid_thws, multi_level_features = image_features
# If multiple images, handle concatenation
if isinstance(image_features, list):
# List has items of shape (1, seq_len, dim)
token_length_list = [i.shape[1] for i in image_features]
image_features = torch.cat(image_features, dim=1) # Concatenate to (1, total_seq_len, dim)
else:
image_features = self.get_model().get_vision_tower()(images)
image_grid_thws = None
multi_level_features = None
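# Project vision features into the language model's embedding space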
image_features = self.get_model().mm_projector(image_features)
# Split concatenated image features back by original lengths (for multi-image case)
if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
start = 0
new_image_features = []
# Split according to token_length_list
for length in token_length_list:
end = start + length
new_image_features.append(image_features[:, start:end, :].squeeze(0))
start = end
image_features = new_image_features
return image_features, image_grid_thws, multi_level_features
# Encode regions (bounding boxes) into features, optionally fusing vision-tower and auxiliary-tower features
def encode_regions(self, images, bbox_list, vt_multi_level_features=None, vt_images_size=None):
aux_image_features_list = self.get_model().get_vision_tower_aux()(images)
region_features = []
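# Two extraction paths: fuse vision-tower and auxiliary-tower features when mm_use_vision_tower_region_feature is set, otherwise rely on the auxiliary tower alone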
if getattr(self.config, "mm_use_vision_tower_region_feature", False):
image_features_list = vt_multi_level_features
for batch_idx, (image_features, aux_image_features) in enumerate(zip(image_features_list, aux_image_features_list)):
if getattr(self.config, "mm_use_simpleFPN_for_vt", False):
multilevel_visual_feats = image_features[-1]
else:
multilevel_visual_feats = image_features
multilevel_aux_visual_feats = aux_image_features["image_features"]
boxes = bbox_list[batch_idx]
# If no boxes provided, fall back to a placeholder box so downstream shapes stay valid
if boxes is None or len(boxes) == 0:
boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_aux_visual_feats[0].device, dtype=torch.float32)
boxes = boxes.to(torch.float32).to(multilevel_aux_visual_feats[0].device)
current_image_height, current_image_width = images[batch_idx].shape[-2:]
vt_image_height, vt_image_width = vt_images_size[batch_idx]
# Rescale boxes from auxiliary-image coordinates into the vision-tower image coordinate frame
scale_height = vt_image_height / current_image_height
scale_width = vt_image_width / current_image_width
vt_boxes = boxes * torch.tensor([scale_width, scale_height, scale_width, scale_height], device=boxes.device)
extracted_region_feat = self.get_model().object_vp_extractor(
aux_multi_level_features=multilevel_aux_visual_feats,
vt_multi_level_features=multilevel_visual_feats,
aux_boxes=[boxes],
vt_boxes=[vt_boxes]
).squeeze(0).to(multilevel_aux_visual_feats[0].dtype)
region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2048]
region_features.append(region_feat)
else:
# Extract region features only from auxiliary vision tower
for batch_idx, image_features in enumerate(aux_image_features_list):
multilevel_visual_feats = image_features["image_features"]
last_feat = image_features["last_feat"]
boxes = bbox_list[batch_idx]
if boxes is None or len(boxes) == 0:
boxes = torch.tensor([[0, 10, 0, 10]], device=multilevel_visual_feats[0].device, dtype=torch.float32)
multi_level_aux_features = multilevel_visual_feats
boxes = boxes.to(torch.float32).to(multi_level_aux_features[0].device)
extracted_region_feat = self.get_model().object_vp_extractor(
multi_level_aux_features,
[boxes],
).squeeze(0).to(multi_level_aux_features[0].dtype)
region_feat = self.get_model().mm_projector_aux(extracted_region_feat) # [num_bbox, 2880]
region_features.append(region_feat)
return region_features
def get_model(self):
# Getter for model. Used to access backbone/model internals.
return self.model
# Convert sequence of input_ids/labels/images/boxes to multimodal embedding and associated masks/ids for transformer input.
def prepare_inputs_labels_for_qwen2_5_vl_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images, images_aux=None, bbox_list=None, image_grid_thws=None
):
# ========================== Vision tower lookup, fast paths, and per-batch feature extraction =============================
vision_tower = self.get_vision_tower()
video_tower = self.get_video_tower()
vision_tower_aux = self.get_vision_tower_aux()
# Fast path: no vision towers / no images, or a single-token decoding step (later generation steps with cached keys/values)
if (vision_tower is None and video_tower is None) or images is None or input_ids.shape[1] == 1:
if past_key_values is not None and (vision_tower is not None or video_tower is not None) and images is not None and input_ids.shape[1] == 1:
target_shape = past_key_values[-1][-1].shape[-2] + 1
attention_mask = torch.cat((attention_mask, torch.ones(
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
dtype=attention_mask.dtype,
device=attention_mask.device
)), dim=1)
position_ids = None
cache_position = torch.tensor([target_shape - 1], device=attention_mask.device)
return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, cache_position
# Indices for images (3D or 2D tensors) and videos (4D tensors)
image_idx = [idx for idx, img in enumerate(images) if img.ndim == 3 or img.ndim == 2]
is_all_image = len(image_idx) == len(images)
video_idx = [idx for idx, vid in enumerate(images) if vid.ndim == 4]
# Stack image and video tensors accordingly for mini-batch processing
if isinstance(vision_tower, Qwen2_5_VlVisionTower):
images_minibatch = [images[idx] for idx in image_idx] if len(image_idx) > 0 else [] # list of [c,h,w], can have variable shapes
else:
images_minibatch = torch.stack([images[idx] for idx in image_idx]) if len(image_idx) > 0 else [] # tensor [mini_b, c, h, w]
videos_minibatch = torch.stack([images[idx] for idx in video_idx]) if len(video_idx) > 0 else [] # tensor [mini_b, c, t, h, w]
# Auxiliary batch for region encoding, if relevant
if vision_tower_aux is not None and images_aux is not None:
images_minibatch_aux = [images_aux[idx].unsqueeze(0) for idx in image_idx] if len(image_idx) > 0 else [] # list of [1, c, h, w]
# tmp_image_features will be indexed to scatter extracted image/video features into original batch positions
tmp_image_features = [None] * (len(image_idx) + len(video_idx))
region_features = None
if getattr(images_minibatch, 'ndim', 0) == 4 or (isinstance(images_minibatch, list) and len(images_minibatch) > 0): # batch consists of images, [mini_b, c, h, w]
if vision_tower is not None:
image_features_minibatch, image_grid_thws_minibatch, vt_multi_level_features_minibatch = self.encode_images(images_minibatch, image_grid_thws) # [mini_b, l, c]
else:
image_features_minibatch = torch.randn(1).to(self.device) # dummy feature for video-only training under tuning
# Map extracted image features back to their places in the original batch
for i, pos in enumerate(image_idx):
tmp_image_features[pos] = image_features_minibatch[i]
# Handle auxiliary region features if enabled and boxes provided
if vision_tower_aux is not None and bbox_list is not None and len(bbox_list) > 0:
if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
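# image_grid_thws stores (t, h, w) in patch units; multiplying the spatial dims by patch_size recovers the pixel size of each vision-tower input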
patch_size = self.get_model().get_vision_tower().config.patch_size
vt_images_size_minibatch = [im_grid_thw[0][-2:]*patch_size for im_grid_thw in image_grid_thws]
region_features = self.encode_regions(images_minibatch_aux, bbox_list, vt_multi_level_features_minibatch, vt_images_size_minibatch) # [mini_b, l, c]
else:
region_features = None
# Same as above, but for video features if any
if getattr(videos_minibatch, 'ndim', 0) == 5: # batch consists of videos, [mini_b, c, t, h, w]
video_features_minibatch = self.encode_videos(videos_minibatch) # fake list [mini_b, t, l, c]
for i, pos in enumerate(video_idx):
tmp_image_features[pos] = video_features_minibatch[i]
# Flatten image feature slot list to proper order for current batch
new_tmp = []
for image in tmp_image_features:
# If multi-image per item, flatten out
if isinstance(image, list):
t = len(image)
for i in range(t):
new_tmp.append(image[i])
else:
new_tmp.append(image)
image_features = new_tmp
# =========================== Now, build multimodal input & target sequences =========================
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
raise NotImplementedError
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
# Default construction of masks etc.
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# For each batch item, strip padded tokens based on attention_mask
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
# No auxiliary vision tower and no region boxes: classic image-text path
if vision_tower_aux is None and (bbox_list is None or all(x is None for x in bbox_list)):
new_input_embeds = []
new_labels = []
new_input_ids = []
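# cur_image_idx walks through image_features across the whole batch; image_nums_in_batch is kept for the grid_thw bookkeeping further below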
cur_image_idx = 0
image_nums_in_batch = []
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
image_nums_in_batch.append(num_images)
# If there are no image markers, just get text features
if num_images == 0:
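# No image markers: splice in a zero-length slice of image features so the vision tower and projector still participate in the backward pass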
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
new_input_ids.append(cur_input_ids)
cur_image_idx += 1
continue
# Split on image token indices: replace them with image features after conversion
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
cur_new_input_ids = []
for i in range(num_images + 1):
# Interleave text and image features
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
cur_new_input_ids.append(cur_input_ids_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx].to(self.device)
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
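# Insert matching placeholder image token ids so get_rope_index can locate the image span later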
cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
cur_new_input_ids = torch.cat(cur_new_input_ids)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
new_input_ids.append(cur_new_input_ids)
# Otherwise an auxiliary tower and/or region boxes are present: interleave region features as well
else:
new_input_embeds = []
new_labels = []
new_input_ids = []
cur_image_idx = 0
image_nums_in_batch = []
for batch_idx, cur_input_ids in enumerate(input_ids):
cur_region_idx = 0
# Detect image and region special token counts
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
num_regions = (cur_input_ids == DEFAULT_REGION_INDEX).sum() if DEFAULT_REGION_INDEX in cur_input_ids else 0
image_nums_in_batch.append(num_images)
# If no markers, just do text embedding for this item
if num_images == 0 and num_regions == 0:
cur_image_features = image_features[cur_image_idx]
cur_region_features = region_features[cur_region_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_region_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
new_input_ids.append(cur_input_ids)
cur_image_idx += 1
continue
# Get all special marker indices (image/region)
image_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
region_indices = torch.where(cur_input_ids == DEFAULT_REGION_INDEX)[0].tolist() if num_regions > 0 else []
all_special_indices = sorted([-1] + image_indices + region_indices + [cur_input_ids.shape[0]])
# Split out plain text chunks between special markers
cur_input_ids_segments = []
cur_labels = labels[batch_idx]
cur_labels_segments = []
for i in range(len(all_special_indices) - 1):
cur_input_ids_segments.append(cur_input_ids[all_special_indices[i]+1:all_special_indices[i+1]])
cur_labels_segments.append(cur_labels[all_special_indices[i]+1:all_special_indices[i+1]])
# Project text ids to word embeddings
split_sizes = [x.shape[0] for x in cur_labels_segments]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_segments))
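# No region markers in this sample: append a zero-length slice of region features so the auxiliary projector stays in the autograd graph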
if num_regions == 0 and vision_tower_aux is not None and region_features is not None:
cur_region_features = region_features[cur_region_idx]
temp_input_embeds = torch.cat([cur_input_embeds, cur_region_features[0:0]], dim=0)
cur_input_embeds = temp_input_embeds
cur_input_embeds_segments = torch.split(cur_input_embeds, split_sizes, dim=0)
# Reassemble text and image/region segments in order
cur_new_input_embeds = []
cur_new_labels = []
cur_new_input_ids = []
for i in range(len(all_special_indices) - 1):
# Insert current text segment
cur_new_input_embeds.append(cur_input_embeds_segments[i])
cur_new_labels.append(cur_labels_segments[i])
cur_new_input_ids.append(cur_input_ids_segments[i])
# If next is image, insert feature representation
if all_special_indices[i+1] in image_indices:
cur_image_features = image_features[cur_image_idx].to(self.device)
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_ids.append(torch.full((cur_image_features.shape[0],), self.config.image_token_id, device=cur_labels.device, dtype=cur_labels.dtype))
# If next is region token, insert extracted region features
elif all_special_indices[i+1] in region_indices:
cur_region_features = region_features[batch_idx][cur_region_idx].to(self.device).unsqueeze(0)
cur_region_idx += 1
cur_new_input_embeds.append(cur_region_features)
cur_new_labels.append(torch.full((cur_region_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_ids.append(torch.full((cur_region_features.shape[0],), DEFAULT_REGION_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
# Combine for this batch item
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
cur_new_input_ids = torch.cat(cur_new_input_ids)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
new_input_ids.append(cur_new_input_ids)
# Truncate sequences to maximum model length, if image+region tokens caused overflow
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Pad sequences in the batch to same length; compute batch masks
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
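# Pad the rebuilt input ids with bos_token_id; these ids are only consumed by get_rope_index below, never fed to the model directly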
new_input_ids_padded = torch.full((batch_size, max_len), self.config.bos_token_id, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
# Left or right padding as per config; fill padded tensors
for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_input_embeds, new_labels, new_input_ids)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
# Left pad: add zeros before text tokens/features
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
new_input_ids_padded[i, -cur_len:] = cur_new_input_ids
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
# Right pad: add zeros after text tokens/features
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
new_input_ids_padded[i, :cur_len] = cur_new_input_ids
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
new_input_ids = new_input_ids_padded
# Only set new_labels if original labels were not None
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
# Similarly handle provided attention_mask/position_ids overrides
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
# For Qwen2.5 vision towers, use and concatenate image_grid_thws for positional computations
if isinstance(self.get_model().get_vision_tower(), Qwen2_5_VlVisionTower):
image_grid_thws = []
cur_image_idx = 0
for num_images in image_nums_in_batch:
if num_images == 0:
cur_image_idx += 1
continue
image_grid_thws += image_grid_thws_minibatch[cur_image_idx:cur_image_idx+num_images]
cur_image_idx += num_images
if len(image_grid_thws) > 0:
image_grid_thws = torch.cat(image_grid_thws, dim=0)
else:
image_grid_thws = None
rope_index_kwargs = {
"input_ids": new_input_ids,
"image_grid_thw": image_grid_thws,
"video_grid_thw": None,
"attention_mask": attention_mask,
}
# Compute new position_ids and rope_deltas for transformer (for rotary embeddings)
position_ids, rope_deltas = self.get_rope_index(**rope_index_kwargs)
cache_position = torch.arange(new_input_embeds.shape[1], device=new_input_embeds.device)
else:
rope_deltas = None
cache_position = None
# Return the rebuilt inputs; input_ids is None because inputs_embeds now carries the full multimodal sequence
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, rope_deltas, cache_position
# Patch forward() of HF CausalLM to allow multimodal embedding with images/regions
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
bbox_list: Optional[torch.FloatTensor] = None,
image_grid_thws: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
rope_deltas,
cache_position
) = self.prepare_inputs_labels_for_qwen2_5_vl_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
images_aux,
bbox_list,
image_grid_thws
)
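# Stash the freshly computed rope deltas on the module (see self.rope_deltas in __init__) so later decoding steps can reuse them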
if rope_deltas is not None:
self.rope_deltas = rope_deltas
# Call base CausalLM forward, with possibly replaced multimodal embeddings
out = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
rope_deltas=rope_deltas,
cache_position=cache_position,
second_per_grid_ts=second_per_grid_ts,
return_dict=return_dict
)
return out
# Prepare model input dict for autoregressive generation (for use with generation methods like generate())
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
second_per_grid_ts=None,
images: Optional[torch.FloatTensor] = None,
images_aux: Optional[torch.FloatTensor] = None,
bbox_list: Optional[torch.FloatTensor] = None,
image_grid_thws: Optional[torch.FloatTensor] = None,
**kwargs,
):
# Wrap parent logic so extra multimodal kwargs are preserved
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
images=images,
images_aux=images_aux,
bbox_list=bbox_list,
image_grid_thws=image_grid_thws,
)
return model_inputs
# Register our config and model with HuggingFace transformers registry
AutoConfig.register("omchat_qwen2_5_vl", OmChatQwen25VLConfig)
AutoModelForCausalLM.register(OmChatQwen25VLConfig, OmChatQwen25VLForCausalLM)
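# Minimal usage sketch (illustrative only): once this module is imported, the Auto*
# registrations above let a checkpoint whose config.json declares model_type
# "omchat_qwen2_5_vl" be loaded through the standard factory methods. The path and
# dtype below are placeholders, not artifacts of this repository.
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained(
#       "/path/to/omchat_qwen2_5_vl_checkpoint",  # placeholder path
#       torch_dtype=torch.bfloat16,
#   ).eval()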