|
|
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F

from transformers import PreTrainedModel, Qwen2_5_VLForConditionalGeneration
from transformers.generation.utils import GenerateOutput
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast

from .configuration_qqmm import QQMMConfig
|
|
|
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    min_dtype: float,
    cache_position: torch.Tensor,
    batch_size: int,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with a static cache, the mask should be as long as the static
            cache, to account for the 0 padding, i.e. the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        min_dtype (`float`):
            The minimum value representable with the dtype `dtype`.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # The mask is already in the inverted 4D form; pass it through unchanged.
        causal_mask = attention_mask
    else:
        # Start fully masked (min_dtype everywhere), then open the causal lower triangle.
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Positions beyond the current cache position stay masked.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for the in-place fill below
            mask_length = attention_mask.shape[-1]
            # Where both the causal mask and the padding mask are 0, the token is padding: re-mask it.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
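# Illustrative sketch (comments only, not executed on import): for sequence_length=3,
# target_length=3 and cache_position=[0, 1, 2], the per-item mask built above is
#     [[0,   min, min],
#      [0,   0,   min],
#      [0,   0,   0  ]]
# where `min` is the most negative value of `dtype`; it is added to attention scores,
# so future positions are suppressed additively rather than via a boolean mask.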
|
|
|
|
|
def padcat_sequences(sequences, value=0, pad_side='right'):
    """Pad a list of `(batch, length)` tensors to a common length, then concatenate along the batch dim."""
    if all(s is None for s in sequences):
        return None
    max_l = max(s.size(1) for s in sequences)
    sequences_ = []
    for seq in sequences:
        if seq.size(1) != max_l:
            pad_len = max_l - seq.size(1)
            # F.pad takes (left, right) amounts for the last dimension.
            pad_len = (0, pad_len) if pad_side == 'right' else (pad_len, 0)
            seq = F.pad(seq, pad_len, value=value)
        sequences_.append(seq)

    return torch.cat(sequences_)
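# Illustrative sketch (comments only): two (batch, length) id tensors of lengths 3 and 5
# are right-padded to length 5 and stacked:
#     a = torch.tensor([[1, 2, 3]])
#     b = torch.tensor([[4, 5, 6, 7, 8]])
#     padcat_sequences([a, b], value=0)  # shape (2, 5); row 0 is [1, 2, 3, 0, 0]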
|
|
|
|
|
class QQMMPreTrainedModel(PreTrainedModel):
    config_class = QQMMConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
|
|
|
|
|
class QQMMForCausalLM(QQMMPreTrainedModel):

    def __init__(self, config, qwen2_5_vl_model=None):
        super().__init__(config)
        if qwen2_5_vl_model is None:
            # Propagate the requested attention implementation to the inner model's config;
            # the Qwen2_5_VLForConditionalGeneration constructor takes only a config, so the
            # setting must travel on config.model_config rather than as a keyword argument.
            if config._attn_implementation_internal is not None:
                config.model_config._attn_implementation = config._attn_implementation_internal
            model = Qwen2_5_VLForConditionalGeneration(config.model_config)
        else:
            model = qwen2_5_vl_model
        self.qwen2_5_vl_model = model
        self.post_init()
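    # Illustrative sketch (comments only; the path is hypothetical):
    #     config = QQMMConfig.from_pretrained("path/to/qqmm")
    #     model = QQMMForCausalLM(config)
    # or wrap an already-instantiated Qwen2.5-VL model:
    #     model = QQMMForCausalLM(config, qwen2_5_vl_model=loaded_qwen)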
|
|
|
|
|
    def make_diy_mask(self, input_ids, attention_mask, embed_token_id, im_start_id, im_end_id):
        if len(attention_mask.shape) == 2:
            # Expand a 2D padding mask into the inverted 4D additive mask edited below.
            sequence_length = attention_mask.shape[1]
            target_length = attention_mask.shape[1]
            dtype = torch.bfloat16  # NOTE: the mask dtype is hard-coded to bfloat16
            device = input_ids.device
            min_dtype = torch.finfo(dtype).min
            cache_position = torch.arange(0, sequence_length, device=attention_mask.device)
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=target_length,
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=attention_mask.shape[0],
            )
        else:
            dtype = torch.bfloat16
            min_dtype = torch.finfo(dtype).min
        # Position of the first embed token per row; rows without one map past the end.
        mask = input_ids == embed_token_id
        embed_index = torch.argmax(mask.float(), dim=1)
        embed_index[embed_index == 0] = input_ids.shape[1]
        embed_index = embed_index.view(-1)
        # Position of the *second* im_start token: clear the first hit, then argmax again.
        mask = input_ids == im_start_id
        im_start_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_start_index_tmp, value=False)
        im_start_index = torch.argmax(mask.float(), dim=1).view(-1)
        # Same trick for the second im_end token.
        mask = input_ids == im_end_id
        im_end_index_tmp = torch.argmax(mask.float(), dim=1).view(-1, 1)
        mask = torch.scatter(mask, dim=1, index=im_end_index_tmp, value=False)
        im_end_index = torch.argmax(mask.float(), dim=1).view(-1)
        # Block every query position after the embed token from attending to the second
        # im_start ... im_end span.
        for b in range(attention_mask.shape[0]):
            attention_mask[b, 0, embed_index[b] + 1:, im_start_index[b]:im_end_index[b] + 2] = min_dtype
        return attention_mask
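    # Illustrative sketch of the intent, as read from the indexing above (comments only):
    # with a row laid out roughly as
    #     <|im_start|> system ... <|im_end|> <|im_start|> user ... <|im_end|> ... <embed> ...
    # the loop masks the key columns of the second <|im_start|> ... <|im_end|> span for all
    # query rows after the <embed> token, so tokens following the embed marker cannot attend
    # back into that span.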
|
|
|
|
|
|
|
|
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        embed_token_id: Optional[int] = None,
        return_emb: Optional[bool] = False,
        cal_loss: Optional[bool] = False,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
|
|
|
|
        # An empty pixel batch means "text only": drop the image inputs entirely.
        if pixel_values is not None and pixel_values.shape[0] == 0:
            pixel_values = None
            image_grid_thw = None
        output_attentions = output_attentions if output_attentions is not None else self.qwen2_5_vl_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.qwen2_5_vl_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.qwen2_5_vl_model.config.use_return_dict
|
|
|
|
|
        if inputs_embeds is None:
            inputs_embeds = self.qwen2_5_vl_model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                # Encode images and splice the visual features over the image placeholder tokens.
                pixel_values = pixel_values.type(self.qwen2_5_vl_model.visual.dtype)
                image_embeds = self.qwen2_5_vl_model.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.qwen2_5_vl_model.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )

                mask = input_ids == self.qwen2_5_vl_model.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                # Same splice for video features over the video placeholder tokens.
                pixel_values_videos = pixel_values_videos.type(self.qwen2_5_vl_model.visual.dtype)
                video_embeds = self.qwen2_5_vl_model.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.qwen2_5_vl_model.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.qwen2_5_vl_model.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)
|
|
|
|
|
|
|
|
        # Compute the 3D (mRoPE) position ids on prefill; reuse the cached rope_deltas on decode steps.
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.qwen2_5_vl_model.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.qwen2_5_vl_model.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                # Cache on the inner model, which is where both the condition above and the
                # returned output read it from.
                self.qwen2_5_vl_model.rope_deltas = rope_deltas

            else:
                # Decode step with a cache: shift fresh positions by the cached rope_deltas.
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.qwen2_5_vl_model.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
|
|
|
|
        outputs = self.qwen2_5_vl_model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
|
|
|
|
|
        if labels is not None:
            # The embed token is a retrieval marker, not a prediction target: exclude it from the loss.
            mask = labels == embed_token_id
            labels[mask] = -100

        logits = self.qwen2_5_vl_model.lm_head(hidden_states)

        if return_emb:
            # Pool the hidden state of the token immediately before the first embed token;
            # rows without an embed token fall back to the last position.
            assert labels is not None, 'labels must be provided to obtain embed'
            hidden_index = torch.argmax(mask.float(), dim=1)
            hidden_index[hidden_index == 0] = labels.shape[1]
            hidden_states = torch.gather(
                hidden_states,
                dim=1,
                index=(hidden_index - 1).view(hidden_index.shape[0], 1, 1).repeat(1, 1, hidden_states.shape[-1]),
            )
            emb = hidden_states[:, 0, :].contiguous()
        else:
            emb = None
|
|
|
|
|
        loss = None
        if labels is not None and cal_loss:
            # Upcast to float for a numerically stable cross-entropy.
            logits = logits.float()
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            if (shift_labels < 0).all().item():
                # Every target is ignored (-100): skip the loss to avoid a NaN from an empty reduction.
                loss = 0.0
            else:
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.qwen2_5_vl_model.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
|
|
|
|
|
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        else:
            outputs = Qwen2_5_VLCausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.qwen2_5_vl_model.rope_deltas,
            )
            if emb is not None:
                outputs['emb'] = emb

            return outputs
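    # Illustrative sketch of the embedding path (comments only; names are hypothetical):
    #     out = model(input_ids=ids, attention_mask=mask, labels=ids.clone(),
    #                 embed_token_id=EMBED_ID, return_emb=True)
    #     emb = out['emb']  # (batch, hidden_size) state of the token just before <embed>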
|
|
|
|
|
|
|
|
    @torch.no_grad()
    def generate(self, input_ids, *args, **kwargs) -> Union[GenerateOutput, torch.LongTensor]:
        return self.qwen2_5_vl_model.generate(input_ids, *args, **kwargs)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.qwen2_5_vl_model.model.enable_input_require_grads()

    def get_input_embeddings(self):
        return self.qwen2_5_vl_model.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qwen2_5_vl_model.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.qwen2_5_vl_model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.qwen2_5_vl_model.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.qwen2_5_vl_model.model = decoder

    def get_decoder(self):
        return self.qwen2_5_vl_model.model
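# End-to-end generation sketch (comments only; the path and tensor names are hypothetical,
# and extra multimodal inputs are forwarded to the wrapped Qwen2.5-VL generate via **kwargs):
#     model = QQMMForCausalLM.from_pretrained("path/to/qqmm", torch_dtype=torch.bfloat16)
#     gen_ids = model.generate(input_ids, attention_mask=mask, pixel_values=pixel_values,
#                              image_grid_thw=grid_thw, max_new_tokens=64)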