| from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast |
| from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast |
| import torch |
| from typing import Optional, List, Union, Tuple |
| from torch.nn import CrossEntropyLoss |
| import numpy as np |
| import transformers.models.qwen2_vl.modeling_qwen2_vl |
| import transformers.models.qwen2_5_vl.modeling_qwen2_5_vl |
| from flash_attn.layers.rotary import apply_rotary_emb |
| from liger_kernel.transformers.fused_linear_cross_entropy import ( |
| LigerFusedLinearCrossEntropyLoss |
| ) |
| from liger_kernel.transformers.swiglu import LigerSwiGLUMLP |
| from liger_kernel.transformers.rms_norm import LigerRMSNorm |
| from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb |
|
|
|
|
def apply_rotary_pos_emb_flashatt_fp32(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate ``tensor`` by ``freqs`` using flash-attn's ``apply_rotary_emb`` in float32.

    The input and the cos/sin tables derived from ``freqs`` are all upcast to
    float32 for the rotation. The rotated result is cast back to ``tensor``'s dtype.

    ``replace_qwen2_5_with_mixed_modality_forward`` installs this function as
    ``apply_rotary_pos_emb_flashatt`` in the Qwen2.5-VL modeling module.
    """
    rotated = apply_rotary_emb(tensor.float(), freqs.cos().float(), freqs.sin().float())
    return rotated.type_as(tensor)
|
|
def replace_qwen_2_with_mixed_modality_forward(use_liger=True):
    """Install a mixed-modality ``forward`` on ``Qwen2VLForConditionalGeneration``.

    The class attribute is patched, so every instance that does not override
    ``forward`` itself picks up the new method.

    Args:
        use_liger: when truthy, install ``qwen_2_mixed_modality_forward_with_flce``
            (Liger fused linear cross-entropy while training). Otherwise install
            the plain ``qwen_2_mixed_modality_forward``.
    """
    qwen2_vl_module = transformers.models.qwen2_vl.modeling_qwen2_vl
    new_forward = (
        qwen_2_mixed_modality_forward_with_flce
        if use_liger
        else qwen_2_mixed_modality_forward
    )
    qwen2_vl_module.Qwen2VLForConditionalGeneration.forward = new_forward
|
|
def replace_qwen2_5_with_mixed_modality_forward(use_liger=True):
    """Monkey-patch the Qwen2.5-VL modeling module for mixed-modality training.

    Both modes replace ``apply_rotary_pos_emb_flashatt`` with the float32
    variant and install a mixed-modality ``forward`` on
    ``Qwen2_5_VLForConditionalGeneration``.

    With ``use_liger`` truthy, the module-level ``Qwen2MLP``, ``Qwen2RMSNorm`` and
    ``apply_multimodal_rotary_pos_emb`` are also swapped for their Liger kernels.
    NOTE(review): module-level class swaps presumably only affect models built
    after this call; confirm it runs before ``from_pretrained``.

    Args:
        use_liger: select the Liger-accelerated forward and kernels.
    """
    qwen2_5_vl_module = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl
    # Shared by both modes.
    qwen2_5_vl_module.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_fp32

    if not use_liger:
        qwen2_5_vl_module.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_mixed_modality_forward
        return

    qwen2_5_vl_module.Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_mixed_modality_forward_with_flce
    qwen2_5_vl_module.Qwen2MLP = LigerSwiGLUMLP
    qwen2_5_vl_module.Qwen2RMSNorm = LigerRMSNorm
    qwen2_5_vl_module.apply_multimodal_rotary_pos_emb = liger_multimodal_rotary_pos_emb
|
|
def qwen_2_mixed_modality_forward_with_flce(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    """Mixed-modality forward for ``Qwen2VLForConditionalGeneration`` using Liger FLCE.

    ``replace_qwen_2_with_mixed_modality_forward(use_liger=True)`` installs this
    function. It accepts text-only, image and video batches.

    For text-only inputs (no ``pixel_values`` and no ``pixel_values_videos``), the
    vision tower runs on an all-zero dummy image and ``0 * mean(output)`` is added
    to the token embeddings. The embeddings stay numerically unchanged, but every
    vision parameter is linked into the autograd graph. Presumably this lets
    data-parallel wrappers see a (zero) gradient for them on every step; confirm
    against the training setup. The dummy pass is skipped on cached decode steps
    (``cache_position[0] != 0``), where it would only add one full vision-tower
    forward per generated token.

    Multimodal rotary position ids are recomputed with ``self.get_rope_index``
    whenever there is no cached context: no ``past_key_values``, an empty cache,
    ``cache_position[0] == 0``, or no stored ``self.rope_deltas``. On cached
    decode steps an ``arange`` shifted by the stored deltas is used instead.

    Loss:
        When ``self.training`` is set and ``labels`` are given, the shifted
        next-token loss comes from ``LigerFusedLinearCrossEntropyLoss``, applied
        to the hidden states and ``lm_head.weight``. ``logits`` is ``None`` on
        this path. Otherwise logits are computed and, if ``labels`` are given,
        a shifted ``CrossEntropyLoss`` is taken over float32 logits.

    Returns:
        If ``return_dict`` is false, ``(loss, logits, *outputs[1:])``, with
        ``loss`` omitted when it is None. Otherwise a
        ``Qwen2VLCausalLMOutputWithPast``.

    Raises:
        ValueError: if the number of image (video) placeholder tokens in
            ``input_ids`` differs from the number of vision features.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # A non-zero first cache position marks an incremental decode step that
    # reuses the KV cache from an earlier call.
    is_decode_step = cache_position is not None and bool(cache_position[0] != 0)

    if pixel_values is None and pixel_values_videos is None and not is_decode_step:
        # Dummy image: a 1 x 98 x 146 patch grid (98 * 146 == 14308 patches), each
        # patch one 1176-value row. Allocated directly with the vision tower's
        # device/dtype instead of on CPU followed by a copy.
        vision_device = self.visual.get_device()
        dummy_pixel = torch.zeros(14308, 1176, device=vision_device, dtype=self.visual.get_dtype())
        dummy_grid = torch.tensor([[1, 98, 146]], device=vision_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place add: never mutate a caller-supplied ``inputs_embeds``. An
        # in-place add on a leaf tensor that requires grad raises in autograd.
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.get_dtype())
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        # Scatter the image features, in order, into the image-token positions.
        image_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        # Scatter the video features, in order, into the video-token positions.
        video_mask = (
            (input_ids == self.config.video_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Position ids are only derived here for a 2D mask or no mask.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # With no cached context the whole sequence is present, so the rope
        # indices must come from its multimodal layout. Without this check,
        # every training batch after the first would reach the decode branch,
        # because ``self.rope_deltas`` is already set.
        no_cached_context = past_key_values is None or (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or no_cached_context
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        else:
            # Cached decode step: continue from the cache position, shifted by
            # the deltas stored by the earlier ``get_rope_index`` call.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            # Same positions on all three rotary axes: shape (3, batch, seq).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        # Fused lm_head + cross-entropy straight from the hidden states, so no
        # logits are produced (or returned) on this path.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float32 before computing the loss.
            logits = logits.float()
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
def qwen_2_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
):
    """Mixed-modality forward for ``Qwen2VLForConditionalGeneration``.

    ``replace_qwen_2_with_mixed_modality_forward(use_liger=False)`` installs this
    function. It accepts text-only, image and video batches.

    For text-only inputs (no ``pixel_values`` and no ``pixel_values_videos``), the
    vision tower runs on an all-zero dummy image and ``0 * mean(output)`` is added
    to the token embeddings. The embeddings stay numerically unchanged, but every
    vision parameter is linked into the autograd graph. Presumably this lets
    data-parallel wrappers see a (zero) gradient for them on every step; confirm
    against the training setup. The dummy pass is skipped on cached decode steps
    (``cache_position[0] != 0``), where it would only add one full vision-tower
    forward per generated token.

    Multimodal rotary position ids are recomputed with ``self.get_rope_index``
    whenever there is no cached context: no ``past_key_values``, an empty cache,
    ``cache_position[0] == 0``, or no stored ``self.rope_deltas``. On cached
    decode steps an ``arange`` shifted by the stored deltas is used instead.

    Returns:
        If ``return_dict`` is false, ``(loss, logits, *outputs[1:])``, with
        ``loss`` omitted when no ``labels`` are given. Otherwise a
        ``Qwen2VLCausalLMOutputWithPast``. With ``labels``, the loss is a shifted
        next-token ``CrossEntropyLoss`` over float32 logits.

    Raises:
        ValueError: if the number of image (video) placeholder tokens in
            ``input_ids`` differs from the number of vision features.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # A non-zero first cache position marks an incremental decode step that
    # reuses the KV cache from an earlier call.
    is_decode_step = cache_position is not None and bool(cache_position[0] != 0)

    if pixel_values is None and pixel_values_videos is None and not is_decode_step:
        # Dummy image: a 1 x 98 x 146 patch grid (98 * 146 == 14308 patches), each
        # patch one 1176-value row. Allocated directly with the vision tower's
        # device/dtype instead of on CPU followed by a copy.
        vision_device = self.visual.get_device()
        dummy_pixel = torch.zeros(14308, 1176, device=vision_device, dtype=self.visual.get_dtype())
        dummy_grid = torch.tensor([[1, 98, 146]], device=vision_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place add: never mutate a caller-supplied ``inputs_embeds``. An
        # in-place add on a leaf tensor that requires grad raises in autograd.
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.get_dtype())
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        # Scatter the image features, in order, into the image-token positions.
        image_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        # Scatter the video features, in order, into the video-token positions.
        video_mask = (
            (input_ids == self.config.video_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Position ids are only derived here for a 2D mask or no mask.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # With no cached context the whole sequence is present, so the rope
        # indices must come from its multimodal layout. Without this check,
        # every training batch after the first would reach the decode branch,
        # because ``self.rope_deltas`` is already set.
        no_cached_context = past_key_values is None or (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or no_cached_context
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        else:
            # Cached decode step: continue from the cache position, shifted by
            # the deltas stored by the earlier ``get_rope_index`` call.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            # Same positions on all three rotary axes: shape (3, batch, seq).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast to float32 before computing the loss.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
def qwen2_5_mixed_modality_forward_with_flce(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    """Mixed-modality forward for ``Qwen2_5_VLForConditionalGeneration`` using Liger FLCE.

    ``replace_qwen2_5_with_mixed_modality_forward(use_liger=True)`` installs this
    function. It accepts text-only, image and video batches.

    For text-only inputs (no ``pixel_values`` and no ``pixel_values_videos``), the
    vision tower runs on an all-zero dummy image and ``0 * mean(output)`` is added
    to the token embeddings. The embeddings stay numerically unchanged, but every
    vision parameter is linked into the autograd graph. Presumably this lets
    data-parallel wrappers see a (zero) gradient for them on every step; confirm
    against the training setup. The dummy pass is skipped on cached decode steps
    (``cache_position[0] != 0``), where it would only add one full vision-tower
    forward per generated token.

    Multimodal rotary position ids are recomputed with ``self.get_rope_index``,
    which also receives ``second_per_grid_ts``, whenever there is no cached
    context: no ``past_key_values``, an empty cache, ``cache_position[0] == 0``,
    or no stored ``self.rope_deltas``. On cached decode steps an ``arange``
    shifted by the stored deltas is used instead.

    Loss:
        When ``self.training`` is set and ``labels`` are given, the shifted
        next-token loss comes from ``LigerFusedLinearCrossEntropyLoss``, applied
        to the hidden states and ``lm_head.weight``. ``logits`` is ``None`` on
        this path. Otherwise logits are computed and, if ``labels`` are given,
        a shifted ``CrossEntropyLoss`` is taken over float32 logits.

    Returns:
        If ``return_dict`` is false, ``(loss, logits, *outputs[1:])``, with
        ``loss`` omitted when it is None. Otherwise a
        ``Qwen2_5_VLCausalLMOutputWithPast``.

    Raises:
        ValueError: if the number of image (video) placeholder tokens in
            ``input_ids`` differs from the number of vision features.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # A non-zero first cache position marks an incremental decode step that
    # reuses the KV cache from an earlier call.
    is_decode_step = cache_position is not None and bool(cache_position[0] != 0)

    if pixel_values is None and pixel_values_videos is None and not is_decode_step:
        # Dummy image: a 1 x 98 x 146 patch grid (98 * 146 == 14308 patches), each
        # patch one 1176-value row. Allocated directly with the vision tower's
        # device/dtype instead of on CPU followed by a copy.
        vision_device = self.visual.device
        dummy_pixel = torch.zeros(14308, 1176, device=vision_device, dtype=self.visual.dtype)
        dummy_grid = torch.tensor([[1, 98, 146]], device=vision_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place add: never mutate a caller-supplied ``inputs_embeds``. An
        # in-place add on a leaf tensor that requires grad raises in autograd.
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        # Scatter the image features, in order, into the image-token positions.
        image_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        # Scatter the video features, in order, into the video-token positions.
        video_mask = (
            (input_ids == self.config.video_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Position ids are only derived here for a 2D mask or no mask.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # With no cached context the whole sequence is present, so the rope
        # indices must come from its multimodal layout. Without this check,
        # every training batch after the first would reach the decode branch,
        # because ``self.rope_deltas`` is already set.
        no_cached_context = past_key_values is None or (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or no_cached_context
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            # Cached decode step: continue from the cache position, shifted by
            # the deltas stored by the earlier ``get_rope_index`` call.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            # Same positions on all three rotary axes: shape (3, batch, seq).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        # Fused lm_head + cross-entropy straight from the hidden states, so no
        # logits are produced (or returned) on this path.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            # Upcast to float32 before computing the loss.
            logits = logits.float()
            shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
            shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
|
|
| |
def qwen2_5_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    """Mixed-modality forward for ``Qwen2_5_VLForConditionalGeneration``.

    ``replace_qwen2_5_with_mixed_modality_forward(use_liger=False)`` installs this
    function. It accepts text-only, image and video batches.

    For text-only inputs (no ``pixel_values`` and no ``pixel_values_videos``), the
    vision tower runs on an all-zero dummy image and ``0 * mean(output)`` is added
    to the token embeddings. The embeddings stay numerically unchanged, but every
    vision parameter is linked into the autograd graph. Presumably this lets
    data-parallel wrappers see a (zero) gradient for them on every step; confirm
    against the training setup. The dummy pass is skipped on cached decode steps
    (``cache_position[0] != 0``), where it would only add one full vision-tower
    forward per generated token.

    Multimodal rotary position ids are recomputed with ``self.get_rope_index``,
    which also receives ``second_per_grid_ts``, whenever there is no cached
    context: no ``past_key_values``, an empty cache, ``cache_position[0] == 0``,
    or no stored ``self.rope_deltas``. On cached decode steps an ``arange``
    shifted by the stored deltas is used instead.

    Returns:
        If ``return_dict`` is false, ``(loss, logits, *outputs[1:])``, with
        ``loss`` omitted when no ``labels`` are given. Otherwise a
        ``Qwen2_5_VLCausalLMOutputWithPast``. With ``labels``, the loss is a
        shifted next-token ``CrossEntropyLoss`` over float32 logits.

    Raises:
        ValueError: if the number of image (video) placeholder tokens in
            ``input_ids`` differs from the number of vision features.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # A non-zero first cache position marks an incremental decode step that
    # reuses the KV cache from an earlier call.
    is_decode_step = cache_position is not None and bool(cache_position[0] != 0)

    if pixel_values is None and pixel_values_videos is None and not is_decode_step:
        # Dummy image: a 1 x 98 x 146 patch grid (98 * 146 == 14308 patches), each
        # patch one 1176-value row. Allocated directly with the vision tower's
        # device/dtype instead of on CPU followed by a copy.
        vision_device = self.visual.device
        dummy_pixel = torch.zeros(14308, 1176, device=vision_device, dtype=self.visual.dtype)
        dummy_grid = torch.tensor([[1, 98, 146]], device=vision_device)
        image_embeds = self.visual(dummy_pixel, grid_thw=dummy_grid)
        # Out-of-place add: never mutate a caller-supplied ``inputs_embeds``. An
        # in-place add on a leaf tensor that requires grad raises in autograd.
        inputs_embeds = inputs_embeds + image_embeds.mean() * 0

    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        # Scatter the image features, in order, into the image-token positions.
        image_mask = (
            (input_ids == self.config.image_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )
        # Scatter the video features, in order, into the video-token positions.
        video_mask = (
            (input_ids == self.config.video_token_id)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

    # Position ids are only derived here for a 2D mask or no mask.
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # With no cached context the whole sequence is present, so the rope
        # indices must come from its multimodal layout. Without this check,
        # every training batch after the first would reach the decode branch,
        # because ``self.rope_deltas`` is already set.
        no_cached_context = past_key_values is None or (
            hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0
        )
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or no_cached_context
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            # Cached decode step: continue from the cache position, shifted by
            # the deltas stored by the earlier ``get_rope_index`` call.
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            # Same positions on all three rotary axes: shape (3, batch, seq).
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Upcast to float32 before computing the loss.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        loss = CrossEntropyLoss()(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )