import inspect
import json
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

logger = logging.get_logger(__name__)

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

    # Evaluated inside the guard: referencing `flash_attn_func` at module level without
    # flash-attn installed would raise a NameError.
    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
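
# A minimal shape sketch of repeat_kv (illustrative, not executed at import time):
#
#     kv = torch.randn(2, 4, 7, 96)       # (batch, num_key_value_heads, seqlen, head_dim)
#     out = repeat_kv(kv, n_rep=4)
#     assert out.shape == (2, 16, 7, 96)  # each KV head is shared by n_rep query heads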


class Idefics2ConnectorConfig(PretrainedConfig):
    r"""
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_hidden_size (`int`, *optional*, defaults to 1152):
            Dimensionality of the vision encoder features fed into the connector.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimensionality of the resampled embeddings (typically the language model hidden size).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the perceiver block.
        resampler_n_latents (`int`, *optional*, defaults to 64):
            Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
        resampler_depth (`int`, *optional*, defaults to 3):
            Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the RMS normalization layers.
        resampler_n_heads (`int`, *optional*, defaults to 16):
            Number of heads in each Transformer block (for multi-headed self-attention).
        resampler_head_dim (`int`, *optional*, defaults to 96):
            Dimensionality of each head projection in the Transformer block.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            Number of key-value heads in the perceiver attention block.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimensionality of the intermediate layer of the modality-projection MLP.
    """

    _auto_class = "AutoConfig"
    model_type = "Idefics2ConnectorConfig"

    def __init__(
        self,
        vision_hidden_size=1152,
        hidden_size=4096,
        hidden_act="silu",
        resampler_n_latents=64,
        resampler_depth=3,
        rms_norm_eps=1e-05,
        resampler_n_heads=16,
        resampler_head_dim=96,
        num_key_value_heads=4,
        attention_dropout=0.0,
        intermediate_size=14336,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_hidden_size = vision_hidden_size
        self.hidden_size = hidden_size
        self.hidden_act = hidden_act
        self.resampler_n_latents = resampler_n_latents
        self.resampler_depth = resampler_depth
        self.rms_norm_eps = rms_norm_eps
        self.resampler_n_heads = resampler_n_heads
        self.num_key_value_heads = num_key_value_heads
        self.resampler_head_dim = resampler_head_dim
        self.attention_dropout = attention_dropout
        self.intermediate_size = intermediate_size
        if self.num_key_value_heads > self.resampler_n_heads:
            raise ValueError(
                f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
                f" resampler_n_heads={self.resampler_n_heads}"
            )

    @classmethod
    def from_pretrained(cls, config_path, **kwargs) -> "PretrainedConfig":
        # Unlike the standard `PretrainedConfig.from_pretrained`, this expects a direct
        # path to a JSON file rather than a model directory or hub id. `hidden_act` is
        # fixed to "silu" and is not read from the JSON.
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        return cls(
            vision_hidden_size=config_dict["vision_hidden_size"],
            hidden_size=config_dict["hidden_size"],
            hidden_act="silu",
            resampler_n_latents=config_dict["resampler_n_latents"],
            resampler_depth=config_dict["resampler_depth"],
            rms_norm_eps=config_dict["rms_norm_eps"],
            intermediate_size=config_dict["intermediate_size"],
        )
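
# Illustrative usage (the JSON path below is hypothetical):
#
#     config = Idefics2ConnectorConfig.from_pretrained("connector_dir/config.json")
#     assert config.hidden_act == "silu"  # always fixed by this loader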


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
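
# Worked example of _get_unpad_data on a toy padding mask (illustrative):
#
#     mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
#     indices, cu_seqlens, max_len = _get_unpad_data(mask)
#     # indices    -> tensor([0, 1, 3, 4, 5])  (flattened positions of real tokens)
#     # cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32)  (cumulative lengths)
#     # max_len    -> 3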


class Idefics2PerceiverAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__()

        # `config` is kept around because the flash-attention subclass reads attributes
        # such as `sliding_window` from it; `layer_idx` is needed by `Cache.update`.
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.resampler_n_heads
        self.head_dim = config.resampler_head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.is_causal = False

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!

        Args:
            latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
            context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
            attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, n_latents, seq + n_latents] representing the attention mask.
            position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
            past_key_value (`Tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states.
            output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
            use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys/values attend over the concatenation of the context and the latents;
        # queries come from the latents only.
        hidden_states = torch.concat([context, latents], dim=-2)

        query_states = self.q_proj(latents)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # Expand grouped KV heads so that every attention head has its own K/V.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

            attn_weights = attn_weights + attention_mask

        # Upcast to fp32 for a numerically stable softmax, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
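
    # Shape sketch for the eager path, using the default config values
    # (hidden_size=4096, 16 heads, head_dim=96, 64 latents); the context length
    # of 576 is illustrative:
    #
    #     attn = Idefics2PerceiverAttention(Idefics2ConnectorConfig())
    #     latents = torch.randn(1, 64, 4096)
    #     context = torch.randn(1, 576, 4096)
    #     out, weights, _ = attn(latents, context, output_attentions=True)
    #     # out.shape     == (1, 64, 4096)
    #     # weights.shape == (1, 16, 64, 640)   # 640 = 576 context + 64 latents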


class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
    """
    Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the
    module stay untouched. The only required change is in the forward pass, where it needs to correctly call the
    public API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # flash_attn < 2.1 generates a top-left aligned causal mask, while what is needed
        # here is bottom-right alignment; this attribute records which convention applies.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Query, key, value projections; keys/values see [context; latents].
        query_states = self.q_proj(latents)
        key_states = self.k_proj(torch.cat([context, latents], dim=-2))
        value_states = self.v_proj(torch.cat([context, latents], dim=-2))

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            # Activate slicing cache only if the config has a `sliding_window` attribute
            # and the cache has grown past it.
            if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
                slicing_tokens = kv_seq_len - self.config.sliding_window

                past_key = past_key_value[0]
                past_value = past_key_value[1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
                        f" head_dim`), got {past_key.shape}"
                    )

                past_key_value = (past_key, past_value)

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # Layer norms trained in float32 can silently upcast the hidden states; cast
        # back to the expected compute dtype for flash attention.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                "The input hidden states seem to have been silently cast to float32; this might be because"
                " you have upcast embedding or layer norm layers in float32. The input will be cast back to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected flash-attention layout: (batch, seqlen, nheads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            use_sliding_windows=False,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
        attn_output = self.o_proj(attn_output)

        # Flash attention does not return attention probabilities.
        attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
        use_sliding_windows=False,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
        first unpad the input, then compute the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
            use_sliding_windows (`bool`, *optional*):
                Whether to activate sliding window attention.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # For flash_attn < 2.1 the mask alignment is only correct for query_length > 1.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            if not use_sliding_windows:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output_unpad = flash_attn_varlen_func(
                    query_states,
                    key_states,
                    value_states,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_in_batch_q,
                    max_seqlen_k=max_seqlen_in_batch_k,
                    dropout_p=dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            if not use_sliding_windows:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                )
            else:
                attn_output = flash_attn_func(
                    query_states,
                    key_states,
                    value_states,
                    dropout,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=(self.config.sliding_window, self.config.sliding_window),
                )

        return attn_output
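
    # The padded path above round-trips through flash-attn's varlen kernel:
    # `_upad_input` removes padding using the mask, `flash_attn_varlen_func` runs on the
    # packed sequences described by `cu_seqlens`, and `pad_input` restores the batch layout.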

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

        # On the first iteration we need to properly re-create the padding mask
        # by slicing it on the proper place
        if kv_seq_len != attention_mask.shape[-1]:
            attention_mask_num_tokens = attention_mask.shape[-1]
            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)

        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -query_length: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )


IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
    "eager": Idefics2PerceiverAttention,
    "flash_attention_2": Idefics2PerceiverFlashAttention2,
}
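
# Sketch of how the implementation is dispatched (illustrative):
#
#     attn_cls = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation]
#     attn = attn_cls(config, layer_idx=0)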


class Idefics2MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        output_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
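
    # This is the gated (SwiGLU-style) MLP: down(act(gate(x)) * up(x)).
    # Shape sketch with the modality-projection sizes used below (illustrative):
    #
    #     mlp = Idefics2MLP(hidden_size=1152, intermediate_size=14336, output_size=4096, hidden_act="silu")
    #     y = mlp(torch.randn(1, 576, 1152))  # -> (1, 576, 4096)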


class Idefics2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Idefics2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
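
    # Computes weight * x / sqrt(mean(x**2, dim=-1) + eps); the statistics are taken in
    # float32 and the result is cast back to the input dtype.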


class Idefics2PerceiverLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_latents = config.resampler_n_latents
        self.depth = config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps

        self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
        self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
        # The perceiver block uses a fixed 4x expansion; `config.intermediate_size` is
        # reserved for the modality-projection MLP below.
        self.mlp = Idefics2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.hidden_size * 4,
            output_size=config.hidden_size,
            hidden_act=config.hidden_act,
        )

    def forward(
        self,
        latents: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            attention_mask=attention_mask,
        )
        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class Idefics2Qformer(nn.Module):
    def __init__(self, config) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then
        returns a Tensor of shape [bsz, n_latents, embed_dim]. The Resampler acts as a form of learned pooling and
        is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206).
        """
        super().__init__()
        # The resampler always runs with flash attention, regardless of what the config
        # originally specified.
        config._attn_implementation = "flash_attention_2"
        self._use_flash_attention_2 = True

        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act
        self.n_latents = config.resampler_n_latents
        self.depth = config.resampler_depth
        self.rms_norm_eps = config.rms_norm_eps

        # Create latents for Perceiver
        self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size))

        self.layers = nn.ModuleList([Idefics2PerceiverLayer(config, idx) for idx in range(self.depth)])
        self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)

    def forward(
        self,
        context: torch.Tensor,
        attention_mask,
    ) -> torch.Tensor:
        # (n_latents, embed_dim) -> (bsz, n_latents, embed_dim)
        latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size()))

        # The latents attend to themselves as well as to the context, so the mask is
        # extended with ones for the latent positions.
        latent_attention_mask = torch.ones(
            (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
        attention_mask = (
            _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
            if not self._use_flash_attention_2
            else attention_mask
        )

        compressed_context = latents
        for perceiver_layer in self.layers:
            # Gradient checkpointing is applied unconditionally (also at inference time);
            # the positional arguments map to (latents, context, attention_mask,
            # position_ids, past_key_value, output_attentions, use_cache).
            layer_outputs = torch.utils.checkpoint.checkpoint(
                perceiver_layer.__call__,
                compressed_context,
                context,
                attention_mask,
                None,
                None,
                False,
                False,
                use_reentrant=True,
            )
            compressed_context = layer_outputs[0]

        compressed_context = self.norm(compressed_context)

        return compressed_context
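
    # End-to-end sketch of the resampler (shapes, dtype, and device are illustrative;
    # flash-attn and a CUDA device are required because `_attn_implementation` is
    # forced to "flash_attention_2" above):
    #
    #     qformer = Idefics2Qformer(config).cuda().to(torch.bfloat16)
    #     context = torch.randn(2, 576, 4096, dtype=torch.bfloat16, device="cuda")
    #     mask = torch.ones(2, 576, dtype=torch.long, device="cuda")
    #     compressed = qformer(context, mask)  # -> (2, 64, 4096): 576 tokens -> 64 latents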


class Idefics2Connector(PreTrainedModel):
    _auto_class = "AutoModel"
    config_class = Idefics2ConnectorConfig

    def __init__(self, config):
        super().__init__(config)
        self.modality_projection = Idefics2MLP(
            hidden_size=config.vision_hidden_size,
            intermediate_size=config.intermediate_size,
            output_size=config.hidden_size,
            hidden_act=config.hidden_act,
        )
        self.perceiver_resampler = Idefics2Qformer(config)
        self.config = config

    def forward(self, image_hidden_states, attention_mask):
        # Project vision features to the language-model width, then compress each
        # image's token sequence down to `resampler_n_latents` tokens.
        image_hidden_states = self.modality_projection(image_hidden_states)
        image_hidden_states = self.perceiver_resampler(context=image_hidden_states, attention_mask=attention_mask)

        vision_hidden_size = image_hidden_states.shape[-1]
        num_images = image_hidden_states.shape[0]
        reshaped_image_hidden_states = image_hidden_states.view(num_images, -1, vision_hidden_size)

        return reshaped_image_hidden_states

    @classmethod
    def from_pretrained(cls, config_path):
        config = Idefics2ConnectorConfig.from_pretrained(f"{config_path}/config.json")
        model = cls(config=config)

        state_dict = torch.load(f"{config_path}/connector.pth", map_location="cpu")
        load_result = model.load_state_dict(state_dict, strict=False)
        print("Loading Idefics2 connector from: {} ({})".format(config_path, load_result))
        return model
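
    # Illustrative end-to-end usage; `path/to/connector_dir` must contain the
    # `config.json` and `connector.pth` files that `from_pretrained` expects, and the
    # model must run on a CUDA device since the resampler forces flash attention:
    #
    #     connector = Idefics2Connector.from_pretrained("path/to/connector_dir").cuda().to(torch.bfloat16)
    #     vision_feats = torch.randn(4, 576, 1152, dtype=torch.bfloat16, device="cuda")
    #     mask = torch.ones(4, 576, dtype=torch.long, device="cuda")
    #     out = connector(vision_feats, mask)  # -> (4, 64, 4096)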