# deepeyes_convert / modeling_qqmm.py
# Uploaded by WINDop via huggingface_hub ("Upload folder using huggingface_hub"), commit 83de97c (verified).
from typing import Optional, List
import torch
from torch import nn
from torch.nn import functional as F
from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM, Qwen2_5_VLForConditionalGeneration
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
from typing import List, Optional, Tuple, Union, Dict
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from .configuration_qqmm import QQMMConfig
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def padcat_sequences(sequences, value=0, pad_side='right'):
    """Pad tensors to a common length along dim 1 and concatenate them along dim 0.

    Args:
        sequences: A list of tensors of rank >= 2 that agree on every dim except 0 and 1,
            or a list whose entries are all ``None``.
        value: Fill value used for padding.
        pad_side: ``'right'`` pads after the existing elements; any other value pads before them.

    Returns:
        The concatenated tensor, or ``None`` when every entry of ``sequences`` is ``None``
        (including an empty list).

    Raises:
        ValueError: if ``sequences`` mixes ``None`` entries with tensors.
    """
    if all(s is None for s in sequences):
        return None
    if any(s is None for s in sequences):
        raise ValueError("padcat_sequences got a mix of None entries and tensors")
    max_l = max(s.size(1) for s in sequences)
    padded = []
    for seq in sequences:
        pad_len = max_l - seq.size(1)
        if pad_len:
            # F.pad's pad tuple lists (before, after) pairs starting from the LAST dim,
            # so emit zero pairs for every dim after dim 1 to pad dim 1 itself.
            trailing = (0, 0) * (seq.dim() - 2)
            pair = (0, pad_len) if pad_side == 'right' else (pad_len, 0)
            seq = F.pad(seq, trailing + pair, value=value)
        padded.append(seq)
    return torch.cat(padded)
class QQMMPreTrainedModel(PreTrainedModel):
    """Base class for QQMM models.

    Binds ``QQMMConfig`` as the configuration class and declares the capability
    flags that the ``PreTrainedModel`` machinery inspects.
    """

    config_class = QQMMConfig

    # Capability flags: gradient checkpointing, cache classes, and the
    # FlashAttention-2 / SDPA attention backends are advertised as supported.
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    # Presumably exempts `past_key_values` from automatic device placement
    # (transformers/accelerate hook) — confirm against the library version in use.
    _skip_keys_device_placement = "past_key_values"
class QQMMForCausalLM(QQMMPreTrainedModel):
    """Wrapper around a ``Qwen2_5_VLForConditionalGeneration`` model.

    ``forward`` mirrors the wrapped model's forward pass (visual-token merge,
    M-RoPE position ids, decoder, lm_head). It can additionally read out an
    "embedding" hidden state located by ``embed_token_id`` in ``labels``.
    Generation and the embedding/decoder accessors are delegated to the wrapped model.
    """

    def __init__(self,
                 config,
                 qwen2_5_vl_model=None):
        """Build the wrapper.

        Args:
            config: A ``QQMMConfig``. When ``qwen2_5_vl_model`` is not given, a fresh
                Qwen2.5-VL model is constructed from ``config.model_config``.
            qwen2_5_vl_model: An already-constructed Qwen2.5-VL model to wrap.
        """
        super().__init__(config)
        if qwen2_5_vl_model is None:
            # NOTE(review): the wrapper's attention implementation
            # (`config._attn_implementation_internal`) is not forwarded here; the inner
            # model is built from `config.model_config` as-is. Confirm this is intended.
            model = Qwen2_5_VLForConditionalGeneration(config.model_config)
        else:
            model = qwen2_5_vl_model
        self.qwen2_5_vl_model = model
        self.post_init()

    @staticmethod
    def _second_occurrence_index(input_ids, token_id):
        """Per-row column index of the second ``token_id`` in ``input_ids`` (0 when a row has fewer than two)."""
        mask = input_ids == token_id
        first = torch.argmax(mask.float(), dim=1).view(-1, 1)
        # Clear the first hit so the next argmax finds the second one.
        mask = torch.scatter(mask, dim=1, index=first, value=False)
        return torch.argmax(mask.float(), dim=1).view(-1)

    def make_diy_mask(self, input_ids, attention_mask, embed_token_id, im_start_id, im_end_id,
                      dtype=torch.bfloat16):
        """Return a 4D additive attention mask with one extra blocked region per row.

        Rows (query positions) after the first ``embed_token_id`` in ``input_ids`` are
        prevented from attending to columns from the second ``im_start_id`` through one
        position past the second ``im_end_id``. Per the original comment, this is the
        ``<|im_start|>user ... <|im_end|>\\n`` turn.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq)``.
            attention_mask: Either a 2D padding mask, which is first converted to a 4D causal
                mask, or an existing 4D additive mask, which is modified in place.
            embed_token_id, im_start_id, im_end_id: Token ids used to locate the regions.
            dtype: Dtype of the mask built from a 2D input; its minimum value is the
                "blocked" value written in both cases.

        Returns:
            The 4D mask.
        """
        min_dtype = torch.finfo(dtype).min
        if attention_mask.dim() == 2:
            sequence_length = attention_mask.shape[1]
            # Build everything on one device so the causal comparison cannot mix devices.
            device = attention_mask.device
            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=sequence_length,
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=torch.arange(0, sequence_length, device=device),
                batch_size=attention_mask.shape[0],
            )
        # argmax yields 0 when a row has no embed token; map that to seq_len so the
        # row slice below is empty and nothing is masked for that row.
        embed_index = torch.argmax((input_ids == embed_token_id).float(), dim=1)
        embed_index[embed_index == 0] = input_ids.shape[1]
        im_start_index = self._second_occurrence_index(input_ids, im_start_id)
        im_end_index = self._second_occurrence_index(input_ids, im_end_id)
        for b in range(attention_mask.shape[0]):
            # +2 also covers the newline following <|im_end|>.
            attention_mask[b, 0, embed_index[b] + 1:, im_start_index[b]:im_end_index[b] + 2] = min_dtype
        return attention_mask

    def _merge_visual_embeds(self, input_ids, inputs_embeds, pixels, grid_thw, token_id, kind):
        """Encode ``pixels`` with the vision tower and scatter the features into the placeholder positions.

        The placeholder positions are those where ``input_ids == token_id``. ``kind``
        ("Image" or "Video") is used only in the error message.

        Raises:
            ValueError: if the number of placeholder tokens differs from the number of features.
        """
        visual = self.qwen2_5_vl_model.visual
        pixels = pixels.type(visual.dtype)
        embeds = visual(pixels, grid_thw=grid_thw)
        token_mask = input_ids == token_id
        n_tokens = token_mask.sum().item()
        n_features = embeds.shape[0]
        if n_tokens != n_features:
            raise ValueError(
                f"{kind} features and {kind.lower()} tokens do not match: tokens: {n_tokens}, features {n_features}"
            )
        scatter_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        embeds = embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        return inputs_embeds.masked_scatter(scatter_mask, embeds)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        embed_token_id: Optional[int] = None,
        return_emb: Optional[bool] = False,
        cal_loss: Optional[bool] = False
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        """Run the wrapped Qwen2.5-VL decoder, optionally returning an embedding and/or a loss.

        Additions over the wrapped model's forward:

        * ``embed_token_id``: positions in ``labels`` equal to this id are excluded from
          the loss. They are set to -100 on a copy; the caller's ``labels`` is not modified.
        * ``return_emb``: also return, under key ``'emb'``, the hidden state at the
          position just before the first ``embed_token_id`` in each row of ``labels``.
          When a row has none, the last position is used. Requires ``labels`` and
          ``embed_token_id``. Only attached when ``return_dict`` is true.
        * ``cal_loss``: the shifted cross-entropy loss is computed only when this is true
          and ``labels`` is given.

        ``rope_deltas`` is accepted for interface compatibility but is not read.

        Raises:
            ValueError: on an image/video token-count mismatch, or when ``return_emb`` is
                requested without both ``labels`` and ``embed_token_id``.
        """
        vl_model = self.qwen2_5_vl_model
        if pixel_values is not None and pixel_values.shape[0] == 0:
            # An empty image batch is treated as "no images".
            pixel_values = None
            image_grid_thw = None
        output_attentions = output_attentions if output_attentions is not None else vl_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else vl_model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else vl_model.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = vl_model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                inputs_embeds = self._merge_visual_embeds(
                    input_ids, inputs_embeds, pixel_values, image_grid_thw,
                    vl_model.config.image_token_id, "Image",
                )
            if pixel_values_videos is not None:
                inputs_embeds = self._merge_visual_embeds(
                    input_ids, inputs_embeds, pixel_values_videos, video_grid_thw,
                    vl_model.config.video_token_id, "Video",
                )
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # With a 4D attention mask the rope deltas cannot be computed (same limitation as upstream).
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            if (
                (cache_position is not None and cache_position[0] == 0)
                or vl_model.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                # Pre-fill: compute the 3D M-RoPE position ids once per generation.
                position_ids, new_rope_deltas = vl_model.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                # Store the deltas where the check above, the decode branch and the returned
                # output read them (the wrapped model). They are also kept on `self` for
                # backward compatibility.
                vl_model.rope_deltas = new_rope_deltas
                self.rope_deltas = new_rope_deltas
            else:
                # Decode: offset plain positions by the deltas cached at pre-fill.
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + vl_model.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `delta` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = vl_model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]

        embed_mask = None
        if labels is not None and embed_token_id is not None:
            embed_mask = labels == embed_token_id
            # Exclude embed tokens from the loss without mutating the caller's tensor.
            labels = labels.masked_fill(embed_mask, -100)
        logits = vl_model.lm_head(hidden_states)

        emb = None
        if return_emb:
            if embed_mask is None:
                raise ValueError('labels and embed_token_id must be provided to obtain embed')
            # First embed-token position per row. argmax yields 0 when absent, which is
            # mapped to seq_len so that `index - 1` selects the last position.
            hidden_index = torch.argmax(embed_mask.float(), dim=1)
            hidden_index[hidden_index == 0] = labels.shape[1]
            gather_index = (hidden_index - 1).view(hidden_index.shape[0], 1, 1).repeat(1, 1, hidden_states.shape[-1])
            emb = torch.gather(hidden_states, dim=1, index=gather_index)[:, 0, :].contiguous()  # (B, C)

        loss = None
        if labels is not None and cal_loss:
            # Upcast to float to avoid precision issues in the loss.
            logits = logits.float()
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            if (shift_labels < 0).all().item():
                # No supervised target in the batch. Return a zero that stays attached to the
                # graph, instead of CrossEntropyLoss's NaN or a plain float, so `backward()` still works.
                loss = shift_logits.sum() * 0.0
            else:
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, vl_model.config.vocab_size)
                # Enable model parallelism: labels must live on the logits' device.
                shift_labels = shift_labels.view(-1).to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        result = Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=vl_model.rope_deltas,
        )
        if emb is not None:
            result['emb'] = emb
        return result

    @torch.no_grad()
    def generate(self, input_ids, *args, **kwargs) -> Union[GenerateOutput, torch.LongTensor]:
        """Delegate generation to the wrapped Qwen2.5-VL model."""
        return self.qwen2_5_vl_model.generate(input_ids, *args, **kwargs)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing, then call ``enable_input_require_grads`` on the wrapped decoder."""
        super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        # Presumably so checkpointed segments see inputs that require grad even when the
        # embedding layer is frozen — confirm.
        self.qwen2_5_vl_model.model.enable_input_require_grads()

    def get_input_embeddings(self):
        """Return the wrapped decoder's input embedding module."""
        return self.qwen2_5_vl_model.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the wrapped decoder's input embedding module."""
        self.qwen2_5_vl_model.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the wrapped model's ``lm_head``."""
        return self.qwen2_5_vl_model.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the wrapped model's ``lm_head``."""
        self.qwen2_5_vl_model.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped model's decoder (``.model``)."""
        self.qwen2_5_vl_model.model = decoder

    def get_decoder(self):
        """Return the wrapped model's decoder (``.model``)."""
        return self.qwen2_5_vl_model.model