|
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast |
|
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast |
|
|
import torch |
|
|
from typing import Optional, List, Union, Tuple |
|
|
from torch.nn import CrossEntropyLoss |
|
|
import numpy as np |
|
|
import transformers.models.qwen2_vl.modeling_qwen2_vl |
|
|
import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl |
|
|
from liger_kernel.transformers.fused_linear_cross_entropy import ( |
|
|
LigerFusedLinearCrossEntropyLoss |
|
|
) |
|
|
|
|
|
def replace_qwen_2_with_mixed_modality_forward(use_liger=True):
    """Monkey-patch ``Qwen2VLForConditionalGeneration.forward`` with a mixed-modality forward.

    Args:
        use_liger: When true, install ``qwen_2_mixed_modality_forward_with_flce``
            (Liger fused linear cross-entropy in training); otherwise install
            ``qwen_2_mixed_modality_forward``.
    """
    patched_forward = (
        qwen_2_mixed_modality_forward_with_flce
        if use_liger
        else qwen_2_mixed_modality_forward
    )
    target_cls = transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration
    target_cls.forward = patched_forward
|
|
|
|
|
def replace_qwen2_5_with_mixed_modality_forward(use_liger=True):
    """Monkey-patch ``Qwen2_5_VLForConditionalGeneration.forward`` with a mixed-modality forward.

    Args:
        use_liger: When true, install ``qwen2_5_mixed_modality_forward_with_flce``
            (Liger fused linear cross-entropy in training); otherwise install
            ``qwen2_5_mixed_modality_forward``.
    """
    patched_forward = (
        qwen2_5_mixed_modality_forward_with_flce
        if use_liger
        else qwen2_5_mixed_modality_forward
    )
    target_cls = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
    target_cls.forward = patched_forward
|
|
|
|
|
def qwen_2_mixed_modality_forward_with_flce(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    """Mixed-modality forward for ``Qwen2VLForConditionalGeneration`` using Liger FLCE in training.

    Installed as the model's ``forward`` by
    ``replace_qwen_2_with_mixed_modality_forward(use_liger=True)``.

    Image / video features produced by ``self.visual`` are written into the
    token embeddings at the positions whose id equals ``config.image_token_id``
    / ``config.video_token_id``. A batch with neither ``pixel_values`` nor
    ``pixel_values_videos`` still runs the vision tower on an all-zero dummy
    input and adds ``mean() * 0`` of its output. The vision parameters thus take
    part in the autograd graph while the embeddings stay numerically unchanged.

    In training mode with ``labels`` the loss is computed by
    ``LigerFusedLinearCrossEntropyLoss`` from the hidden states and
    ``lm_head.weight``. This function then never computes logits and returns
    ``logits=None``. Otherwise logits are computed. If ``labels`` is given, a
    shifted next-token ``CrossEntropyLoss`` is taken on float32 logits, and those
    float32 logits are returned.

    Args:
        input_ids: Token ids. Also used to locate image / video placeholder tokens
            and to build position ids.
        inputs_embeds: Optional precomputed embeddings. The caller's tensor is not
            modified in place.
        pixel_values / pixel_values_videos: Vision inputs passed to ``self.visual``
            together with ``image_grid_thw`` / ``video_grid_thw``.
        rope_deltas: Accepted for signature compatibility but not read. The cached
            ``self.rope_deltas`` is used instead.
        cache_position: Forwarded to ``self.model``. It also selects between
            computing fresh position ids and reusing ``self.rope_deltas``.

    Returns:
        ``(loss,) + (logits,) + outputs[1:]`` (``loss`` omitted when None) if
        ``return_dict`` is false. Otherwise a ``Qwen2VLCausalLMOutputWithPast``.

    Raises:
        ValueError: If the number of placeholder tokens differs from the number
            of feature rows returned by the vision tower.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    def _merge_visual(embeds, values, grid_thw, token_id, label):
        # Encode ``values`` with the vision tower and scatter the feature rows, in
        # order, into the positions of ``embeds`` whose token id is ``token_id``.
        values = values.type(self.visual.get_dtype())
        features = self.visual(values, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{label} features and {label.lower()} tokens do not match: "
                f"tokens: {n_tokens}, features {n_features}"
            )
        mask = token_mask.unsqueeze(-1).expand_as(embeds).to(embeds.device)
        features = features.to(embeds.device, embeds.dtype)
        return embeds.masked_scatter(mask, features)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a zero dummy so its parameters
        # are still in the graph (presumably so all ranks in distributed training
        # reduce the same gradients -- confirm). The contribution is scaled by 0.
        # Dummy shape: 14308 = 1 * 98 * 146 patches (the dummy grid). 1176 is the
        # per-patch width the tower expects (presumably 3 * 2 * 14 * 14 -- confirm).
        # NOTE(review): this branch also runs outside training (any call without
        # pixel inputs), costing a full vision-tower pass over the dummy.
        visual_device = self.visual.get_device()
        dummy_pixel = torch.zeros(14308, 1176, device=visual_device, dtype=self.visual.get_dtype())
        dummy_grid = torch.tensor([[1, 98, 146]], device=visual_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place on purpose: ``+=`` would mutate a caller-supplied tensor.
        # It would also fail on a leaf tensor that requires grad (e.g. frozen
        # embeddings whose output was marked ``requires_grad_``).
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values, image_grid_thw, self.config.image_token_id, "Image"
        )

    if pixel_values_videos is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values_videos, video_grid_thw, self.config.video_token_id, "Video"
        )

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Build position ids (leading axis of size 3) unless the caller supplied them
    # or passed an attention mask that is not 2-D.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            # First step, or nothing cached yet: compute from the full input and
            # cache the returned deltas for subsequent steps.
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        else:
            # Later steps: an arange over the current tokens shifted by the
            # cached deltas (moved to the embeddings' device, as in Qwen2.5).
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        # Fused linear + cross-entropy straight from the hidden states; this
        # function does not compute the logits tensor in this branch.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
        # Labels must live on the same device as the hidden states (model parallel).
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_hidden_states.device)
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast so the cross-entropy is computed in float32.
            logits = logits.float()
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
|
|
|
def qwen_2_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    """Mixed-modality forward for ``Qwen2VLForConditionalGeneration`` with standard cross-entropy.

    Installed as the model's ``forward`` by
    ``replace_qwen_2_with_mixed_modality_forward(use_liger=False)``.

    Image / video features produced by ``self.visual`` are written into the
    token embeddings at the positions whose id equals ``config.image_token_id``
    / ``config.video_token_id``. A batch with neither ``pixel_values`` nor
    ``pixel_values_videos`` still runs the vision tower on an all-zero dummy
    input and adds ``mean() * 0`` of its output. The vision parameters thus take
    part in the autograd graph while the embeddings stay numerically unchanged.

    Logits are always computed. With ``labels``, a shifted next-token
    ``CrossEntropyLoss`` is taken on float32 logits, and those float32 logits
    are returned.

    Args:
        input_ids: Token ids. Also used to locate image / video placeholder tokens
            and to build position ids.
        inputs_embeds: Optional precomputed embeddings. The caller's tensor is not
            modified in place.
        pixel_values / pixel_values_videos: Vision inputs passed to ``self.visual``
            together with ``image_grid_thw`` / ``video_grid_thw``.
        rope_deltas: Accepted for signature compatibility but not read. The cached
            ``self.rope_deltas`` is used instead.
        cache_position: Forwarded to ``self.model``. It also selects between
            computing fresh position ids and reusing ``self.rope_deltas``.

    Returns:
        ``(loss,) + (logits,) + outputs[1:]`` (``loss`` omitted when None) if
        ``return_dict`` is false. Otherwise a ``Qwen2VLCausalLMOutputWithPast``.

    Raises:
        ValueError: If the number of placeholder tokens differs from the number
            of feature rows returned by the vision tower.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    def _merge_visual(embeds, values, grid_thw, token_id, label):
        # Encode ``values`` with the vision tower and scatter the feature rows, in
        # order, into the positions of ``embeds`` whose token id is ``token_id``.
        values = values.type(self.visual.get_dtype())
        features = self.visual(values, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{label} features and {label.lower()} tokens do not match: "
                f"tokens: {n_tokens}, features {n_features}"
            )
        mask = token_mask.unsqueeze(-1).expand_as(embeds).to(embeds.device)
        features = features.to(embeds.device, embeds.dtype)
        return embeds.masked_scatter(mask, features)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a zero dummy so its parameters
        # are still in the graph (presumably so all ranks in distributed training
        # reduce the same gradients -- confirm). The contribution is scaled by 0.
        # Dummy shape: 14308 = 1 * 98 * 146 patches (the dummy grid). 1176 is the
        # per-patch width the tower expects (presumably 3 * 2 * 14 * 14 -- confirm).
        # NOTE(review): this branch also runs outside training (any call without
        # pixel inputs), costing a full vision-tower pass over the dummy.
        visual_device = self.visual.get_device()
        dummy_pixel = torch.zeros(14308, 1176, device=visual_device, dtype=self.visual.get_dtype())
        dummy_grid = torch.tensor([[1, 98, 146]], device=visual_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place on purpose: ``+=`` would mutate a caller-supplied tensor.
        # It would also fail on a leaf tensor that requires grad (e.g. frozen
        # embeddings whose output was marked ``requires_grad_``).
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values, image_grid_thw, self.config.image_token_id, "Image"
        )

    if pixel_values_videos is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values_videos, video_grid_thw, self.config.video_token_id, "Video"
        )

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Build position ids (leading axis of size 3) unless the caller supplied them
    # or passed an attention mask that is not 2-D.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            # First step, or nothing cached yet: compute from the full input and
            # cache the returned deltas for subsequent steps.
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        else:
            # Later steps: an arange over the current tokens shifted by the
            # cached deltas (moved to the embeddings' device, as in Qwen2.5).
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast so the cross-entropy is computed in float32.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
|
|
|
def qwen2_5_mixed_modality_forward_with_flce(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    """Mixed-modality forward for ``Qwen2_5_VLForConditionalGeneration`` using Liger FLCE in training.

    Installed as the model's ``forward`` by
    ``replace_qwen2_5_with_mixed_modality_forward(use_liger=True)``.

    Image / video features produced by ``self.visual`` are written into the
    token embeddings at the positions whose id equals ``config.image_token_id``
    / ``config.video_token_id``. A batch with neither ``pixel_values`` nor
    ``pixel_values_videos`` still runs the vision tower on an all-zero dummy
    input and adds ``mean() * 0`` of its output. The vision parameters thus take
    part in the autograd graph while the embeddings stay numerically unchanged.

    In training mode with ``labels`` the loss is computed by
    ``LigerFusedLinearCrossEntropyLoss`` from the hidden states and
    ``lm_head.weight``. This function then never computes logits and returns
    ``logits=None``. Otherwise logits are computed. If ``labels`` is given, a
    shifted next-token ``CrossEntropyLoss`` is taken on float32 logits, and those
    float32 logits are returned.

    Args:
        input_ids: Token ids. Also used to locate image / video placeholder tokens
            and to build position ids.
        inputs_embeds: Optional precomputed embeddings. The caller's tensor is not
            modified in place.
        pixel_values / pixel_values_videos: Vision inputs passed to ``self.visual``
            together with ``image_grid_thw`` / ``video_grid_thw``.
        rope_deltas: Accepted for signature compatibility but not read. The cached
            ``self.rope_deltas`` is used instead.
        cache_position: Forwarded to ``self.model``. It also selects between
            computing fresh position ids and reusing ``self.rope_deltas``.
        second_per_grid_ts: Passed through to ``self.get_rope_index``.

    Returns:
        ``(loss,) + (logits,) + outputs[1:]`` (``loss`` omitted when None) if
        ``return_dict`` is false. Otherwise a ``Qwen2_5_VLCausalLMOutputWithPast``.

    Raises:
        ValueError: If the number of placeholder tokens differs from the number
            of feature rows returned by the vision tower.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    def _merge_visual(embeds, values, grid_thw, token_id, label):
        # Encode ``values`` with the vision tower and scatter the feature rows, in
        # order, into the positions of ``embeds`` whose token id is ``token_id``.
        values = values.type(self.visual.dtype)
        features = self.visual(values, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{label} features and {label.lower()} tokens do not match: "
                f"tokens: {n_tokens}, features {n_features}"
            )
        mask = token_mask.unsqueeze(-1).expand_as(embeds).to(embeds.device)
        features = features.to(embeds.device, embeds.dtype)
        return embeds.masked_scatter(mask, features)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a zero dummy so its parameters
        # are still in the graph (presumably so all ranks in distributed training
        # reduce the same gradients -- confirm). The contribution is scaled by 0.
        # Dummy shape: 14308 = 1 * 98 * 146 patches (the dummy grid). 1176 is the
        # per-patch width the tower expects (presumably 3 * 2 * 14 * 14 -- confirm).
        # NOTE(review): this branch also runs outside training (any call without
        # pixel inputs), costing a full vision-tower pass over the dummy.
        visual_device = self.visual.device
        dummy_pixel = torch.zeros(14308, 1176, device=visual_device, dtype=self.visual.dtype)
        dummy_grid = torch.tensor([[1, 98, 146]], device=visual_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place on purpose: ``+=`` would mutate a caller-supplied tensor.
        # It would also fail on a leaf tensor that requires grad (e.g. frozen
        # embeddings whose output was marked ``requires_grad_``).
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values, image_grid_thw, self.config.image_token_id, "Image"
        )

    if pixel_values_videos is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values_videos, video_grid_thw, self.config.video_token_id, "Video"
        )

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Build position ids (leading axis of size 3) unless the caller supplied them
    # or passed an attention mask that is not 2-D.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            # First step, or nothing cached yet: compute from the full input and
            # cache the returned deltas for subsequent steps.
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            # Later steps: an arange over the current tokens shifted by the cached deltas.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        # Fused linear + cross-entropy straight from the hidden states; this
        # function does not compute the logits tensor in this branch.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
        # Labels must live on the same device as the hidden states (model parallel).
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_hidden_states.device)
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast so the cross-entropy is computed in float32.
            logits = logits.float()
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
|
|
|
|
|
|
|
|
|
def qwen2_5_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    """Mixed-modality forward for ``Qwen2_5_VLForConditionalGeneration`` with standard cross-entropy.

    Installed as the model's ``forward`` by
    ``replace_qwen2_5_with_mixed_modality_forward(use_liger=False)``.

    Image / video features produced by ``self.visual`` are written into the
    token embeddings at the positions whose id equals ``config.image_token_id``
    / ``config.video_token_id``. A batch with neither ``pixel_values`` nor
    ``pixel_values_videos`` still runs the vision tower on an all-zero dummy
    input and adds ``mean() * 0`` of its output. The vision parameters thus take
    part in the autograd graph while the embeddings stay numerically unchanged.

    Logits are always computed. With ``labels``, a shifted next-token
    ``CrossEntropyLoss`` is taken on float32 logits, and those float32 logits
    are returned.

    Args:
        input_ids: Token ids. Also used to locate image / video placeholder tokens
            and to build position ids.
        inputs_embeds: Optional precomputed embeddings. The caller's tensor is not
            modified in place.
        pixel_values / pixel_values_videos: Vision inputs passed to ``self.visual``
            together with ``image_grid_thw`` / ``video_grid_thw``.
        rope_deltas: Accepted for signature compatibility but not read. The cached
            ``self.rope_deltas`` is used instead.
        cache_position: Forwarded to ``self.model``. It also selects between
            computing fresh position ids and reusing ``self.rope_deltas``.
        second_per_grid_ts: Passed through to ``self.get_rope_index``.

    Returns:
        ``(loss,) + (logits,) + outputs[1:]`` (``loss`` omitted when None) if
        ``return_dict`` is false. Otherwise a ``Qwen2_5_VLCausalLMOutputWithPast``.

    Raises:
        ValueError: If the number of placeholder tokens differs from the number
            of feature rows returned by the vision tower.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    def _merge_visual(embeds, values, grid_thw, token_id, label):
        # Encode ``values`` with the vision tower and scatter the feature rows, in
        # order, into the positions of ``embeds`` whose token id is ``token_id``.
        values = values.type(self.visual.dtype)
        features = self.visual(values, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = features.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{label} features and {label.lower()} tokens do not match: "
                f"tokens: {n_tokens}, features {n_features}"
            )
        mask = token_mask.unsqueeze(-1).expand_as(embeds).to(embeds.device)
        features = features.to(embeds.device, embeds.dtype)
        return embeds.masked_scatter(mask, features)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    if pixel_values is None and pixel_values_videos is None:
        # Text-only batch: run the vision tower on a zero dummy so its parameters
        # are still in the graph (presumably so all ranks in distributed training
        # reduce the same gradients -- confirm). The contribution is scaled by 0.
        # Dummy shape: 14308 = 1 * 98 * 146 patches (the dummy grid). 1176 is the
        # per-patch width the tower expects (presumably 3 * 2 * 14 * 14 -- confirm).
        # NOTE(review): this branch also runs outside training (any call without
        # pixel inputs), costing a full vision-tower pass over the dummy.
        visual_device = self.visual.device
        dummy_pixel = torch.zeros(14308, 1176, device=visual_device, dtype=self.visual.dtype)
        dummy_grid = torch.tensor([[1, 98, 146]], device=visual_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place on purpose: ``+=`` would mutate a caller-supplied tensor.
        # It would also fail on a leaf tensor that requires grad (e.g. frozen
        # embeddings whose output was marked ``requires_grad_``).
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values, image_grid_thw, self.config.image_token_id, "Image"
        )

    if pixel_values_videos is not None:
        inputs_embeds = _merge_visual(
            inputs_embeds, pixel_values_videos, video_grid_thw, self.config.video_token_id, "Video"
        )

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Build position ids (leading axis of size 3) unless the caller supplied them
    # or passed an attention mask that is not 2-D.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            # First step, or nothing cached yet: compute from the full input and
            # cache the returned deltas for subsequent steps.
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            # Later steps: an arange over the current tokens shifted by the cached deltas.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast so the cross-entropy is computed in float32.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )