import torch
import torch.nn as nn

from typing import Optional, Union, Tuple

from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import (
    shift_tokens_right,
    VisionEncoderDecoderModel,
)
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers import PreTrainedModel
from transformers.models.pixtral.modeling_pixtral import apply_rotary_pos_emb, PixtralAttention, PixtralVisionModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from pixtral_encoder_decoder.config import PixtralVisionModelBatchConfig, VisionPixtralEncoderDecoderConfig


def position_ids_in_meshgrid_batch(patch_embeds, max_width):
    """Compute row-major position ids for a batch of patch embeddings.

    All images in the batch share the same padded patch grid, so the ids are
    computed once for that grid and repeated along the batch dimension.
    """
    height, width = patch_embeds.shape[-2:]
    mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
    # Index each (row, col) patch into a virtual grid that is `max_width` patches wide.
    ids = h_grid * max_width + v_grid

    ids = ids.reshape(1, -1).repeat(patch_embeds.shape[0], 1)
    return ids
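

# Worked example (a sketch, not executed at import time): for a 2 x 3 patch grid and
# max_width=4, every image in the batch gets the same row-major ids:
#
#     patch_embeds = torch.zeros(2, 8, 2, 3)  # (batch, dim, h_patches, w_patches)
#     position_ids_in_meshgrid_batch(patch_embeds, max_width=4)
#     # tensor([[0, 1, 2, 4, 5, 6],
#     #         [0, 1, 2, 4, 5, 6]])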


def create_attention_mask_batch(w, h, image_sizes, patch_size):
    """Build a (batch, h, w) mask of valid patches on the padded h x w patch grid."""

    def single_image_mask(num_rows, num_cols):
        # 1.0 for patches covered by the real image, 0.0 for padding patches.
        return ((torch.arange(h).unsqueeze(1) < num_rows) & (torch.arange(w).unsqueeze(0) < num_cols)).float()

    mask = [single_image_mask(size[0] // patch_size, size[1] // patch_size) for size in image_sizes]
    return torch.stack(mask, dim=0)
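

# Worked example (a sketch with hypothetical sizes): on a padded 2 x 3 patch grid
# (h=2, w=3) with patch_size=16, two images of 32x48 and 16x16 pixels give
#
#     create_attention_mask_batch(3, 2, [(32, 48), (16, 16)], 16)
#     # tensor([[[1., 1., 1.],
#     #          [1., 1., 1.]],
#     #         [[1., 0., 0.],
#     #          [0., 0., 0.]]])
#
# A 1.0 marks a real patch; padding patches are later turned into a large negative
# additive bias by `_prepare_4d_attention_mask` before reaching the attention scores.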


def pixtral_attention_fix_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Input shape: Batch x Time x Channel"""

    batch_size, patches, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, patches, -1)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights


# Module-level monkey patch: every PixtralAttention instance (including ones created
# before this import) now uses the forward defined above.
PixtralAttention.forward = pixtral_attention_fix_forward
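

# Masking convention used above (a sketch): `attention_mask` is additive, so padded key
# positions carry a large negative bias and vanish after the softmax:
#
#     scores = torch.tensor([[1.0, 2.0, 3.0]])
#     bias = torch.tensor([[0.0, 0.0, torch.finfo(torch.float32).min]])
#     nn.functional.softmax(scores + bias, dim=-1)
#     # ~ tensor([[0.2689, 0.7311, 0.0000]])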


class PixtralVisionModelBatch(PixtralVisionModel):
    config_class = PixtralVisionModelBatchConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Returns:
            The transformer output (a tuple or `BaseModelOutput`) whose last hidden state
            has shape (batch_size, num_patches, hidden_size). Padding patches are kept in
            the sequence and masked out through the attention mask.
        """
        if attention_mask is None and image_sizes is None:
            raise ValueError("Either `attention_mask` or `image_sizes` must be defined")

        patch_embeds = self.patch_conv(pixel_values)

        if attention_mask is None:
            h, w = patch_embeds.shape[-2:]
            attention_mask = create_attention_mask_batch(w, h, image_sizes, self.patch_size).to(patch_embeds.device)
            attention_mask = attention_mask.flatten(start_dim=-2)

        position_ids = position_ids_in_meshgrid_batch(
            patch_embeds, max_width=self.config.image_size // self.config.patch_size
        )
        position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

        # (B, D, h, w) -> (B, h * w, D)
        patch_embeds = patch_embeds.flatten(start_dim=-2).transpose(-1, -2)

        # Expand the (B, num_patches) binary mask to an additive (B, 1, num_patches, num_patches) mask.
        attention_mask = _prepare_4d_attention_mask(attention_mask, torch.float)

        patch_embeds = self.ln_pre(patch_embeds)

        out = self.transformer(
            patch_embeds,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        return out
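

# Usage sketch (hypothetical sizes and config values, not executed at import time):
#
#     config = PixtralVisionModelBatchConfig(...)
#     encoder = PixtralVisionModelBatch(config)
#     pixel_values = torch.zeros(2, 3, 64, 64)           # batch zero-padded to a common size
#     image_sizes = torch.tensor([[64, 64], [32, 48]])   # true (height, width) of each image
#     out = encoder(pixel_values=pixel_values, image_sizes=image_sizes, return_dict=True)
#     out.last_hidden_state.shape                        # (2, num_patches, hidden_size)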


class VisionPixtralEncoderDecoder(VisionEncoderDecoderModel):
    config_class = VisionPixtralEncoderDecoderConfig

    def __init__(self, config,
                 encoder: Optional[PixtralVisionModelBatch] = None,
                 decoder: Optional[PreTrainedModel] = None):
        super().__init__(config, encoder, decoder)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            if encoder_attention_mask is None and image_sizes is None:
                raise ValueError("Either `encoder_attention_mask` or `image_sizes` must be defined")
            if encoder_attention_mask is None:
                h, w = pixel_values.shape[-2:]
                h = h // self.encoder.patch_size
                w = w // self.encoder.patch_size
                encoder_attention_mask = create_attention_mask_batch(w, h, image_sizes, self.encoder.patch_size)
                encoder_attention_mask = encoder_attention_mask.to(pixel_values.device)
                encoder_attention_mask = encoder_attention_mask.flatten(start_dim=-2)

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder hidden states into the decoder's hidden size if they differ.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # When only labels are given, build the decoder inputs by shifting the labels right.
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]

            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.decoder.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
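

# End-to-end training sketch (hypothetical names, not executed at import time; any causal
# LM with cross-attention is assumed to serve as `decoder`):
#
#     model = VisionPixtralEncoderDecoder(config, encoder=encoder, decoder=decoder)
#     outputs = model(
#         pixel_values=pixel_values,   # (B, 3, H_pad, W_pad), batch zero-padded to a common size
#         image_sizes=image_sizes,     # (B, 2) true pixel sizes, used to mask padding patches
#         labels=labels,               # (B, T) target token ids, also used to build decoder inputs
#     )
#     outputs.loss.backward()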