| |
| |
| |
| |
| @@ -2016,6 +2016,9 @@ def forward( |
| class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): |
| config_class = Blip2Config |
| main_input_name = "pixel_values" |
| + _supports_cache_class = True |
| + _supports_static_cache = True |
| + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) |
| |
| def __init__(self, config: Blip2Config): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -1284,13 +1284,13 @@ def forward( |
| |
| if pixel_values is not None: |
| image_tokens = self.get_image_tokens(pixel_values) |
| - n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item() |
| - n_image_features = image_tokens.shape[0] * image_tokens.shape[1] |
| - if n_image_tokens_in_text != n_image_features: |
| + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel(): |
| + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum() |
| + n_image_features = image_tokens.shape[0] * image_tokens.shape[1] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" |
| ) |
| - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id |
| image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) |
| input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) |
| |
| |
| |
| |
| |
| @@ -25,7 +25,7 @@ |
| import torch.nn as nn |
| |
| from ...activations import ACT2FN |
| -from ...cache_utils import Cache, HybridCache |
| +from ...cache_utils import Cache, HybridCache, StaticCache |
| from ...generation import GenerationMixin |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| @@ -701,7 +701,7 @@ def _update_causal_mask( |
| |
| dtype, device = input_tensor.dtype, input_tensor.device |
| sequence_length = input_tensor.shape[1] |
| - if isinstance(past_key_values, HybridCache): |
| + if isinstance(past_key_values, (HybridCache, StaticCache)): |
| target_length = past_key_values.get_max_cache_shape() |
| else: |
| target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
| |
| |
| |
| |
| @@ -25,7 +25,7 @@ |
| import torch.nn as nn |
| |
| from ...activations import ACT2FN |
| -from ...cache_utils import Cache, HybridCache |
| +from ...cache_utils import Cache, HybridCache, StaticCache |
| from ...generation import GenerationMixin |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_outputs import ( |
| @@ -713,7 +713,7 @@ def _update_causal_mask( |
| |
| dtype, device = input_tensor.dtype, input_tensor.device |
| sequence_length = input_tensor.shape[1] |
| - if isinstance(past_key_values, HybridCache): |
| + if isinstance(past_key_values, (HybridCache, StaticCache)): |
| target_length = past_key_values.get_max_cache_shape() |
| else: |
| target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
| |
| |
| |
| |
| @@ -20,7 +20,7 @@ |
| import torch.utils.checkpoint |
| |
| from ...activations import ACT2FN |
| -from ...cache_utils import Cache, HybridCache |
| +from ...cache_utils import Cache, HybridCache, StaticCache |
| from ...configuration_utils import PretrainedConfig |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| from ...modeling_outputs import ( |
| @@ -550,7 +550,7 @@ def _update_causal_mask( |
| |
| dtype, device = input_tensor.dtype, input_tensor.device |
| sequence_length = input_tensor.shape[1] |
| - if isinstance(past_key_values, HybridCache): |
| + if isinstance(past_key_values, (HybridCache, StaticCache)): |
| target_length = past_key_values.get_max_cache_shape() |
| else: |
| target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] |
| |
| |
| |
| |
| @@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 151859): |
| The image token index to encode the image prompt. |
| image_seq_length (`int`, *optional*, defaults to 576): |
| @@ -161,13 +159,11 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=151859, |
| image_seq_length=576, |
| pad_token_id=-1, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.image_seq_length = image_seq_length |
| self.pad_token_id = pad_token_id |
| |
| |
| |
| |
| @@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of GotOcr2 isn't meant for training from scratch - only |
| @@ -748,89 +750,6 @@ def get_image_features( |
| image_outputs = self.vision_tower(pixel_values).last_hidden_state |
| return self.multi_modal_projector(image_outputs) |
| |
| - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): |
| - num_images, num_image_patches, embed_dim = image_features.shape |
| - batch_size, sequence_length = input_ids.shape |
| - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == self.config.image_token_index |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # Compute the maximum embed dimension |
| - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length |
| - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) |
| - |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 |
| - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] |
| - if left_padding: |
| - new_token_positions += nb_image_pad[:, None] # offset for left padding |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - final_embedding = torch.zeros( |
| - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - if labels is not None: |
| - final_labels = torch.full( |
| - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - if labels is not None: |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) |
| - image_to_overwrite = torch.full( |
| - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device |
| - ) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - if left_padding: |
| - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) |
| - else: |
| - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 |
| - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) |
| - image_to_overwrite &= padding_mask |
| - |
| - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): |
| - raise ValueError( |
| - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." |
| - ) |
| - |
| - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. |
| - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) |
| - indices_to_mask = new_token_positions[batch_indices, pad_indices] |
| - |
| - final_embedding[batch_indices, indices_to_mask] = 0 |
| - |
| - if labels is None: |
| - final_labels = None |
| - |
| - return final_embedding, final_attention_mask, final_labels, position_ids |
| - |
| @add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| def forward( |
| |
| |
| |
| |
| @@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 151859): |
| The image token index to encode the image prompt. |
| image_seq_length (`int`, *optional*, defaults to 576): |
| @@ -199,13 +197,11 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=151859, |
| image_seq_length=576, |
| pad_token_id=-1, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.image_seq_length = image_seq_length |
| self.pad_token_id = pad_token_id |
| |
| |
| |
| |
| @@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = False # TODO (fix me): compilation fails due to a stide error? |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| @@ -129,8 +129,8 @@ def forward( |
| |
| cos, sin = position_embeddings |
| query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) |
| - query = torch.cat((query, query_pass), dim=-1) |
| - key = torch.cat((key, key_pass), dim=-1) |
| + query = torch.cat((query, query_pass), dim=-1).contiguous() |
| + key = torch.cat((key, key_pass), dim=-1).contiguous() |
| |
| # Cache QKV values |
| if layer_past is not None: |
| |
| |
| |
| |
| @@ -1108,6 +1108,7 @@ def forward( |
| router_logits=all_router_logits, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| def _update_causal_mask( |
| self, |
| attention_mask: torch.Tensor, |
| @@ -1116,13 +1117,8 @@ def _update_causal_mask( |
| past_key_values: Cache, |
| output_attentions: bool, |
| ): |
| - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static |
| - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. |
| - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using |
| - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 |
| - |
| if self.config._attn_implementation == "flash_attention_2": |
| - if attention_mask is not None and 0.0 in attention_mask: |
| + if attention_mask is not None and (attention_mask == 0.0).any(): |
| return attention_mask |
| return None |
| |
| @@ -1143,7 +1139,6 @@ def _update_causal_mask( |
| return None |
| |
| dtype, device = input_tensor.dtype, input_tensor.device |
| - min_dtype = torch.finfo(dtype).min |
| sequence_length = input_tensor.shape[1] |
| if using_static_cache: |
| target_length = past_key_values.get_max_cache_shape() |
| @@ -1154,25 +1149,17 @@ def _update_causal_mask( |
| else past_seen_tokens + sequence_length + 1 |
| ) |
| |
| - if attention_mask is not None and attention_mask.dim() == 4: |
| - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing |
| - causal_mask = attention_mask |
| - else: |
| - causal_mask = torch.full( |
| - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| - ) |
| - if sequence_length != 1: |
| - causal_mask = torch.triu(causal_mask, diagonal=1) |
| - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) |
| - if attention_mask is not None: |
| - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| - mask_length = attention_mask.shape[-1] |
| - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| - padding_mask = padding_mask == 0 |
| - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| - padding_mask, min_dtype |
| - ) |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| if ( |
| self.config._attn_implementation == "sdpa" |
| and attention_mask is not None |
| @@ -1182,6 +1169,7 @@ def _update_causal_mask( |
| # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| |
| return causal_mask |
| |
| |
| |
| |
| @@ -1290,6 +1290,9 @@ def forward( |
| class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin): |
| config_class = InstructBlipConfig |
| main_input_name = "pixel_values" |
| + _supports_cache_class = True |
| + _supports_static_cache = True |
| + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) |
| |
| def __init__(self, config: InstructBlipConfig): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -1284,6 +1284,9 @@ def forward( |
| class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): |
| config_class = InstructBlipVideoConfig |
| main_input_name = "pixel_values" |
| + _supports_cache_class = True |
| + _supports_static_cache = True |
| + _supports_quantized_cache = False # not all LM bacbones support (e.g. T5) |
| |
| def __init__(self, config: InstructBlipVideoConfig): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32000): |
| The image token index to encode the image prompt. |
| projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): |
| @@ -83,7 +81,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32000, |
| projector_hidden_act="gelu", |
| vision_feature_select_strategy="default", |
| @@ -92,7 +89,6 @@ def __init__( |
| multimodal_projector_bias=True, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.projector_hidden_act = projector_hidden_act |
| self.image_seq_length = image_seq_length |
| |
| |
| |
| |
| @@ -28,6 +28,7 @@ |
| from ...utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of Llava isn't meant for training from scratch - only |
| @@ -321,89 +324,6 @@ def get_image_features( |
| image_features = self.multi_modal_projector(selected_image_feature) |
| return image_features |
| |
| - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): |
| - num_images, num_image_patches, embed_dim = image_features.shape |
| - batch_size, sequence_length = input_ids.shape |
| - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == self.config.image_token_index |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # Compute the maximum embed dimension |
| - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length |
| - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) |
| - |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 |
| - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] |
| - if left_padding: |
| - new_token_positions += nb_image_pad[:, None] # offset for left padding |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - final_embedding = torch.zeros( |
| - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - if labels is not None: |
| - final_labels = torch.full( |
| - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - if labels is not None: |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) |
| - image_to_overwrite = torch.full( |
| - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device |
| - ) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - if left_padding: |
| - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) |
| - else: |
| - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 |
| - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) |
| - image_to_overwrite &= padding_mask |
| - |
| - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): |
| - raise ValueError( |
| - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." |
| - ) |
| - |
| - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. |
| - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) |
| - indices_to_mask = new_token_positions[batch_indices, pad_indices] |
| - |
| - final_embedding[batch_indices, indices_to_mask] = 0 |
| - |
| - if labels is None: |
| - final_labels = None |
| - |
| - return final_embedding, final_attention_mask, final_labels, position_ids |
| - |
| @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") |
| @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| @@ -499,14 +419,14 @@ def forward( |
| image_sizes=image_sizes, |
| ) |
| |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| - n_image_features = image_features.shape[0] * image_features.shape[1] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] * image_features.shape[1] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| |
| |
| |
| |
| @@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32000): |
| The image token index to encode the image prompt. |
| projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): |
| @@ -88,7 +86,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32000, |
| projector_hidden_act="gelu", |
| vision_feature_select_strategy="default", |
| @@ -99,7 +96,6 @@ def __init__( |
| multimodal_projector_bias=True, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.projector_hidden_act = projector_hidden_act |
| self.image_seq_length = image_seq_length |
| |
| |
| |
| |
| @@ -31,6 +31,7 @@ |
| from ...utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -245,6 +246,8 @@ class LlavaNextPreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of LlavaNext isn't meant for training from scratch - only |
| @@ -405,245 +408,6 @@ def set_decoder(self, decoder): |
| def get_decoder(self): |
| return self.language_model.get_decoder() |
| |
| - def _merge_input_ids_with_image_features( |
| - self, |
| - image_features, |
| - feature_lens, |
| - inputs_embeds, |
| - input_ids, |
| - attention_mask, |
| - position_ids=None, |
| - labels=None, |
| - image_token_index=None, |
| - ignore_index=-100, |
| - ): |
| - """ |
| - Merge input_ids with with image features into final embeddings |
| - |
| - Args: |
| - image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): |
| - All vision vectors of all images in the batch |
| - feature_lens (`torch.LongTensor` of shape `(num_images)`): |
| - The length of visual embeddings of each image as stacked in `image_features` |
| - inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): |
| - Token embeddings before merging with visual embeddings |
| - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Input_ids of tokens, possibly filled with image token |
| - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Mask to avoid performing attention on padding token indices. |
| - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
| - config.n_positions - 1]`. |
| - labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) |
| - :abels need to be recalculated to support training (if provided) |
| - image_token_index (`int`, *optional*) |
| - Token id used to indicate the special "image" token. Defaults to `config.image_token_index` |
| - ignore_index (`int`, *optional*) |
| - Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. |
| - Returns: |
| - final_embedding, final_attention_mask, position_ids, final_labels |
| - |
| - Explanation: |
| - each image has variable length embeddings, with length specified by feature_lens |
| - image_features is concatenation of all visual embed vectors |
| - task: fill each <image> with the correct number of visual embeddings |
| - Example: |
| - X (5 patches), Y (3 patches), Z (8) |
| - X, Y are in the same sequence (in-context learning) |
| - if right padding |
| - input_ids: [ |
| - a b c d e f X g h i j k Y l m |
| - o p q r Z s t u v _ _ _ _ _ _ |
| - ] |
| - input_ids should be: [ |
| - a b c d e f X X X X X g h i j k Y Y Y l m |
| - o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ |
| - ] |
| - labels should be: [ |
| - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m |
| - o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ |
| - ] |
| - elif left padding |
| - input_ids: [ |
| - a b c d e f X g h i j k Y l m |
| - _ _ _ _ _ _ o p q r Z s t u v |
| - ] |
| - input_ids should be: [ |
| - a b c d e f X X X X X g h i j k Y Y Y l m |
| - _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v |
| - ] |
| - labels should be: [ |
| - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m |
| - _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v |
| - ] |
| - Edge cases: |
| - * If tokens are same but image token sizes are different, then cannot infer left or right padding |
| - ```python |
| - cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) |
| - chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw) |
| - prompts = [ |
| - "[INST] <image>\nWhat is shown in this image? [/INST]", |
| - "[INST] <image>\nWhat is shown in this image? [/INST]", |
| - ] |
| - inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") |
| - chart_img has 2634 tokens, while cat_img has 2340 tokens |
| - ``` |
| - |
| - input_ids: [ |
| - a b c d X g h |
| - i j Y k l m n |
| - ] |
| - where X is 3 tokens while Y is 5, this mean after merge |
| - if left-padding (batched generation) |
| - input_ids should be: [ |
| - _ _ a b c d X X X g h |
| - i j Y Y Y Y Y k l m n |
| - ] |
| - elif (right padding) (training) |
| - input_ids should be: [ |
| - a b c d X X X g h _ _ |
| - i j Y Y Y Y Y k l m n |
| - ] |
| - """ |
| - image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index |
| - ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index |
| - |
| - if self.training and self.padding_side == "left": |
| - logger.warning_once( |
| - "Padding side is set to 'left' but the model is in training mode. For training " |
| - "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. " |
| - "If that's intended, ignore this warning" |
| - ) |
| - if not self.training and self.padding_side == "right": |
| - logger.warning_once( |
| - "Padding side is set to 'right' but the model is in inference mode. For correct " |
| - "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " |
| - "If that's intended, ignore this warning" |
| - ) |
| - |
| - with torch.no_grad(): |
| - # ! in llava 1.6, number of patches is variable |
| - num_images = feature_lens.size(0) |
| - num_image_features, embed_dim = image_features.shape |
| - if feature_lens.sum() != num_image_features: |
| - raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") |
| - batch_size = input_ids.shape[0] |
| - _left_padding = torch.any(attention_mask[:, 0] == 0) |
| - _right_padding = torch.any(attention_mask[:, -1] == 0) |
| - |
| - left_padding = self.padding_side == "left" |
| - if batch_size > 1: |
| - if _left_padding and _right_padding: |
| - raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") |
| - elif _right_padding and left_padding: |
| - left_padding = False |
| - elif _left_padding and not left_padding: |
| - left_padding = True |
| - |
| - # Whether to turn off right padding |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == image_token_index |
| - # special_image_token_mask: [bsz, seqlen] |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # num_special_image_tokens: [bsz] |
| - # Reserve for padding of num_images |
| - total_num_special_image_tokens = torch.sum(special_image_token_mask) |
| - if total_num_special_image_tokens != num_images: |
| - raise ValueError( |
| - f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})." |
| - ) |
| - # Compute the maximum embed dimension |
| - # max_image_feature_lens is max_feature_lens per batch |
| - feature_lens = feature_lens.to(input_ids.device) |
| - feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) |
| - feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) |
| - embed_sequence_lengths = ( |
| - (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum |
| - ) |
| - max_embed_dim = embed_sequence_lengths.max() |
| - |
| - batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - # ! instead of special_image_token_mask * (num_image_patches - 1) |
| - # special_image_token_mask * (num_feature_len - 1) |
| - special_image_token_mask = special_image_token_mask.long() |
| - special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 |
| - new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 |
| - if left_padding: |
| - # shift right token positions so that they are ending at the same number |
| - # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] |
| - new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] |
| - |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - final_embedding = torch.zeros( |
| - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - final_input_ids = torch.full( |
| - (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - input_ids = input_ids.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] |
| - final_labels = None |
| - if labels is not None: |
| - labels = labels.to(target_device) |
| - final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) |
| - with torch.no_grad(): |
| - image_to_overwrite = torch.full( |
| - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device |
| - ) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) |
| - embed_indices = embed_indices.expand(batch_size, max_embed_dim) |
| - embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) |
| - |
| - if left_padding: |
| - # exclude padding on the left |
| - max_embed_dim = max_embed_dim.to(target_device) |
| - val = (max_embed_dim - embed_indices) <= embed_seq_lens |
| - else: |
| - # exclude padding on the right |
| - val = embed_indices < embed_seq_lens |
| - image_to_overwrite &= val |
| - |
| - if image_to_overwrite.sum() != num_image_features: |
| - raise ValueError( |
| - f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " |
| - f"The number of image tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of image given to the model is {num_images}. " |
| - f"This prevents correct indexing and breaks batch generation." |
| - ) |
| - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids |
| - |
| def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): |
| """ |
| Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. |
| @@ -875,14 +639,14 @@ def forward( |
| image_newline=self.image_newline, |
| ) |
| |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| |
| |
| |
| |
| @@ -38,8 +38,6 @@ class LlavaNextVideoConfig(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32001): |
| The image token index to encode the image prompt. |
| projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): |
| @@ -96,7 +94,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32001, |
| projector_hidden_act="gelu", |
| multimodal_projector_bias=True, |
| @@ -116,7 +113,6 @@ def __init__( |
| self.spatial_pool_stride = spatial_pool_stride |
| self.image_seq_length = image_seq_length |
| self.video_seq_length = video_seq_length |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.projector_hidden_act = projector_hidden_act |
| self.multimodal_projector_bias = multimodal_projector_bias |
| |
| |
| |
| |
| @@ -32,7 +32,13 @@ |
| from ...image_processing_utils import select_best_resolution |
| from ...modeling_outputs import ModelOutput |
| from ...modeling_utils import PreTrainedModel |
| -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings |
| +from ...utils import ( |
| + add_start_docstrings, |
| + add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| + logging, |
| + replace_return_docstrings, |
| +) |
| from ...utils.deprecation import deprecate_kwarg |
| from ..auto import AutoModel, AutoModelForCausalLM |
| from .configuration_llava_next_video import LlavaNextVideoConfig |
| @@ -153,6 +159,8 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only |
| @@ -440,245 +448,6 @@ def set_decoder(self, decoder): |
| def get_decoder(self): |
| return self.language_model.get_decoder() |
| |
| - def _merge_input_ids_with_image_features( |
| - self, |
| - image_features, |
| - feature_lens, |
| - inputs_embeds, |
| - input_ids, |
| - attention_mask, |
| - position_ids=None, |
| - labels=None, |
| - image_token_index=None, |
| - ignore_index=-100, |
| - ): |
| - """ |
| - Merge input_ids with with image features into final embeddings |
| - |
| - Args: |
| - image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): |
| - All vision vectors of all images in the batch |
| - feature_lens (`torch.LongTensor` of shape `(num_images)`): |
| - The length of visual embeddings of each image as stacked in `image_features` |
| - inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): |
| - Token embeddings before merging with visual embeddings |
| - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Input_ids of tokens, possibly filled with image token |
| - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Mask to avoid performing attention on padding token indices. |
| - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
| - config.n_positions - 1]`. |
| - labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) |
| - :abels need to be recalculated to support training (if provided) |
| - image_token_index (`int`, *optional*) |
| - Token id used to indicate the special "image" token. Defaults to `config.image_token_index` |
| - ignore_index (`int`, *optional*) |
| - Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. |
| - Returns: |
| - final_embedding, final_attention_mask, position_ids, final_labels |
| - |
| - Explanation: |
| - each image has variable length embeddings, with length specified by feature_lens |
| - image_features is concatenation of all visual embed vectors |
| - task: fill each <image> with the correct number of visual embeddings |
| - Example: |
| - X (5 patches), Y (3 patches), Z (8) |
| - X, Y are in the same sequence (in-context learning) |
| - if right padding |
| - input_ids: [ |
| - a b c d e f X g h i j k Y l m |
| - o p q r Z s t u v _ _ _ _ _ _ |
| - ] |
| - input_ids should be: [ |
| - a b c d e f X X X X X g h i j k Y Y Y l m |
| - o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ |
| - ] |
| - labels should be: [ |
| - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m |
| - o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ |
| - ] |
| - elif left padding |
| - input_ids: [ |
| - a b c d e f X g h i j k Y l m |
| - _ _ _ _ _ _ o p q r Z s t u v |
| - ] |
| - input_ids should be: [ |
| - a b c d e f X X X X X g h i j k Y Y Y l m |
| - _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v |
| - ] |
| - labels should be: [ |
| - a b c d e f _ _ _ _ _ g h i j k _ _ _ l m |
| - _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v |
| - ] |
| - Edge cases: |
| - * If tokens are same but image token sizes are different, then cannot infer left or right padding |
| - ```python |
| - cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) |
| - chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw) |
| - prompts = [ |
| - "[INST] <image>\nWhat is shown in this image? [/INST]", |
| - "[INST] <image>\nWhat is shown in this image? [/INST]", |
| - ] |
| - inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") |
| - chart_img has 2634 tokens, while cat_img has 2340 tokens |
| - ``` |
| - |
| - input_ids: [ |
| - a b c d X g h |
| - i j Y k l m n |
| - ] |
| - where X is 3 tokens while Y is 5, this mean after merge |
| - if left-padding (batched generation) |
| - input_ids should be: [ |
| - _ _ a b c d X X X g h |
| - i j Y Y Y Y Y k l m n |
| - ] |
| - elif (right padding) (training) |
| - input_ids should be: [ |
| - a b c d X X X g h _ _ |
| - i j Y Y Y Y Y k l m n |
| - ] |
| - """ |
| - image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index |
| - ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index |
| - |
| - if self.training and self.padding_side == "left": |
| - logger.warning_once( |
| - "Padding side is set to 'left' but the model is in training mode. For training " |
| - "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. " |
| - "If that's intended, ignore this warning" |
| - ) |
| - if not self.training and self.padding_side == "right": |
| - logger.warning_once( |
| - "Padding side is set to 'right' but the model is in inference mode. For correct " |
| - "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " |
| - "If that's intended, ignore this warning" |
| - ) |
| - |
| - with torch.no_grad(): |
| - # ! in llava 1.6, number of patches is variable |
| - num_images = feature_lens.size(0) |
| - num_image_features, embed_dim = image_features.shape |
| - if feature_lens.sum() != num_image_features: |
| - raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") |
| - batch_size = input_ids.shape[0] |
| - _left_padding = torch.any(attention_mask[:, 0] == 0) |
| - _right_padding = torch.any(attention_mask[:, -1] == 0) |
| - |
| - left_padding = self.padding_side == "left" |
| - if batch_size > 1: |
| - if _left_padding and _right_padding: |
| - raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") |
| - elif _right_padding and left_padding: |
| - left_padding = False |
| - elif _left_padding and not left_padding: |
| - left_padding = True |
| - |
| - # Whether to turn off right padding |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == image_token_index |
| - # special_image_token_mask: [bsz, seqlen] |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # num_special_image_tokens: [bsz] |
| - # Reserve for padding of num_images |
| - total_num_special_image_tokens = torch.sum(special_image_token_mask) |
| - if total_num_special_image_tokens != num_images: |
| - raise ValueError( |
| - f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})." |
| - ) |
| - # Compute the maximum embed dimension |
| - # max_image_feature_lens is max_feature_lens per batch |
| - feature_lens = feature_lens.to(input_ids.device) |
| - feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) |
| - feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) |
| - embed_sequence_lengths = ( |
| - (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum |
| - ) |
| - max_embed_dim = embed_sequence_lengths.max() |
| - |
| - batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - # ! instead of special_image_token_mask * (num_image_patches - 1) |
| - # special_image_token_mask * (num_feature_len - 1) |
| - special_image_token_mask = special_image_token_mask.long() |
| - special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 |
| - new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 |
| - if left_padding: |
| - # shift right token positions so that they are ending at the same number |
| - # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] |
| - new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] |
| - |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - final_embedding = torch.zeros( |
| - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - final_input_ids = torch.full( |
| - (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - input_ids = input_ids.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] |
| - final_labels = None |
| - if labels is not None: |
| - labels = labels.to(target_device) |
| - final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) |
| - with torch.no_grad(): |
| - image_to_overwrite = torch.full( |
| - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device |
| - ) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) |
| - embed_indices = embed_indices.expand(batch_size, max_embed_dim) |
| - embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) |
| - |
| - if left_padding: |
| - # exclude padding on the left |
| - max_embed_dim = max_embed_dim.to(target_device) |
| - val = (max_embed_dim - embed_indices) <= embed_seq_lens |
| - else: |
| - # exclude padding on the right |
| - val = embed_indices < embed_seq_lens |
| - image_to_overwrite &= val |
| - |
| - if image_to_overwrite.sum() != num_image_features: |
| - raise ValueError( |
| - f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " |
| - f"The number of image tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of image given to the model is {num_images}. " |
| - f"This prevents correct indexing and breaks batch generation." |
| - ) |
| - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids |
| - |
| def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None): |
| """ |
| Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. |
| @@ -948,14 +717,14 @@ def forward( |
| image_newline=self.image_newline, |
| ) |
| |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| @@ -970,14 +739,14 @@ def forward( |
| video_features = torch.cat(video_features, dim=0) |
| video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) |
| |
| - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| - n_video_features = video_features.shape[0] |
| - if n_video_tokens != n_video_features: |
| + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): |
| + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| + n_video_features = video_features.shape[0] |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) |
| |
| |
| |
| |
| |
| @@ -30,6 +30,7 @@ |
| |
| from ...configuration_utils import PretrainedConfig |
| from ...utils import ( |
| + is_torchdynamo_compiling, |
| logging, |
| ) |
| from ..auto import CONFIG_MAPPING, AutoConfig |
| @@ -52,8 +53,6 @@ class LlavaNextVideoConfig(PretrainedConfig): |
| The config object or dictionary of the vision backbone. |
| text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): |
| The config object or dictionary of the text backbone. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32001): |
| The image token index to encode the image prompt. |
| projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): |
| @@ -110,7 +109,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32001, |
| projector_hidden_act="gelu", |
| multimodal_projector_bias=True, |
| @@ -130,7 +128,6 @@ def __init__( |
| self.spatial_pool_stride = spatial_pool_stride |
| self.image_seq_length = image_seq_length |
| self.video_seq_length = video_seq_length |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.projector_hidden_act = projector_hidden_act |
| self.multimodal_projector_bias = multimodal_projector_bias |
| @@ -479,14 +476,14 @@ def forward( |
| image_newline=self.image_newline, |
| ) |
| |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| @@ -501,14 +498,14 @@ def forward( |
| video_features = torch.cat(video_features, dim=0) |
| video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) |
| |
| - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| - n_video_features = video_features.shape[0] |
| - if n_video_tokens != n_video_features: |
| + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): |
| + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| + n_video_features = video_features.shape[0] |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) |
| |
| |
| |
| |
| |
| @@ -30,6 +30,7 @@ |
| from ...modeling_utils import PreTrainedModel |
| from ...utils import ( |
| add_start_docstrings, |
| + is_torchdynamo_compiling, |
| logging, |
| ) |
| from ...utils.deprecation import deprecate_kwarg |
| @@ -250,7 +251,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn_2 = True |
| _supports_cache_class = True |
| - _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support |
| + _supports_static_cache = True |
| _supports_quantized_cache = True |
| _supports_sdpa = True |
| |
| @@ -712,19 +713,15 @@ def forward( |
| image_newline=self.image_newline, |
| vision_aspect_ratio=vision_aspect_ratio, |
| ) |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] |
| |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = ( |
| - (input_ids == self.config.image_token_index) |
| - .unsqueeze(-1) |
| - .expand_as(inputs_embeds) |
| - .to(inputs_embeds.device) |
| - ) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| @@ -741,18 +738,14 @@ def forward( |
| video_features = torch.cat((video_features, image_newline), dim=1) |
| video_features = video_features.flatten(0, 1) |
| |
| - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| - n_video_features = video_features.shape[0] |
| - if n_video_tokens != n_video_features: |
| + special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| + special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): |
| + n_video_tokens = (input_ids == self.config.video_token_index).sum() |
| + n_video_features = video_features.shape[0] |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| - special_video_mask = ( |
| - (input_ids == self.config.video_token_index) |
| - .unsqueeze(-1) |
| - .expand_as(inputs_embeds) |
| - .to(inputs_embeds.device) |
| - ) |
| video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) |
| |
| |
| |
| |
| |
| @@ -22,10 +22,10 @@ |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, StaticCache |
| from ...generation import GenerationMixin |
| from ...modeling_attn_mask_utils import ( |
| - _prepare_4d_causal_attention_mask, |
| - _prepare_4d_causal_attention_mask_for_sdpa, |
| + AttentionMaskConverter, |
| ) |
| from ...modeling_outputs import ( |
| BaseModelOutputWithPast, |
| @@ -98,6 +98,7 @@ class OPTAttention(nn.Module): |
| def __init__( |
| self, |
| config: OPTConfig, |
| + layer_idx: int = None, |
| **kwargs, |
| ): |
| super().__init__() |
| @@ -106,6 +107,13 @@ def __init__( |
| self.num_heads = config.num_attention_heads |
| self.dropout = config.attention_dropout |
| self.enable_bias = config.enable_bias |
| + self.layer_idx = layer_idx |
| + if layer_idx is None: |
| + logger.warning_once( |
| + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " |
| + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| self.head_dim = self.embed_dim // self.num_heads |
| self.is_causal = True |
| @@ -122,9 +130,6 @@ def __init__( |
| self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) |
| self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) |
| |
| - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: |
| - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
| - |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -134,52 +139,33 @@ def forward( |
| output_attentions: bool = False, |
| # isn't needed in normal attention, but needed in flash attention so to keep the signature same |
| position_ids: Optional[torch.Tensor] = None, |
| - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| + cache_position: Optional[torch.Tensor] = None, |
| + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: |
| """Input shape: Batch x Time x Channel""" |
| bsz, tgt_len, _ = hidden_states.size() |
| |
| # get query proj |
| query_states = self.q_proj(hidden_states) * self.scaling |
| - # get key, value proj |
| - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
| - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
| - if past_key_value is not None: |
| - # reuse k, v, self_attention |
| - key_states = torch.cat([past_key_value[0], key_states], dim=2) |
| - value_states = torch.cat([past_key_value[1], value_states], dim=2) |
| - |
| - past_key_value = (key_states, value_states) |
| + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| - proj_shape = (bsz * self.num_heads, -1, self.head_dim) |
| - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) |
| - key_states = key_states.view(*proj_shape) |
| - value_states = value_states.view(*proj_shape) |
| + key_states = self.k_proj(hidden_states) |
| + value_states = self.v_proj(hidden_states) |
| + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| - src_len = key_states.size(1) |
| - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) |
| - |
| - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): |
| - raise ValueError( |
| - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" |
| - f" {attn_weights.size()}" |
| + if past_key_value is not None: |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + key_states, value_states = past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| ) |
| |
| + attn_weights = torch.matmul(query_states, key_states.transpose(3, 2)) |
| if attention_mask is not None: |
| - if attention_mask.size() != (bsz, 1, tgt_len, src_len): |
| - raise ValueError( |
| - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" |
| - ) |
| - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask |
| - attn_weights = torch.max( |
| - attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) |
| - ) |
| - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
| + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| + attn_weights = attn_weights + causal_mask |
| |
| # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 |
| - if attn_weights.dtype == torch.float16: |
| - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) |
| - else: |
| - attn_weights = nn.functional.softmax(attn_weights, dim=-1) |
| + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
| |
| if layer_head_mask is not None: |
| if layer_head_mask.size() != (self.num_heads,): |
| @@ -187,39 +173,19 @@ def forward( |
| f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" |
| f" {layer_head_mask.size()}" |
| ) |
| - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
| - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
| - |
| - if output_attentions: |
| - # this operation is a bit awkward, but it's required to |
| - # make sure that attn_weights keeps its gradient. |
| - # In order to do so, attn_weights have to be reshaped |
| - # twice and have to be reused in the following |
| - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
| - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) |
| - else: |
| - attn_weights_reshaped = None |
| + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights |
| |
| attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| + attn_output = torch.matmul(attn_probs, value_states) |
| |
| - attn_output = torch.bmm(attn_probs, value_states) |
| - |
| - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): |
| - raise ValueError( |
| - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" |
| - f" {attn_output.size()}" |
| - ) |
| - |
| - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) |
| - attn_output = attn_output.transpose(1, 2) |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| |
| # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be |
| # partitioned aross GPUs when using tensor-parallelism. |
| attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) |
| - |
| attn_output = self.out_proj(attn_output) |
| |
| - return attn_output, attn_weights_reshaped, past_key_value |
| + return attn_output, attn_probs, past_key_value |
| |
| |
| class OptFlashAttention2(OPTAttention): |
| @@ -245,33 +211,33 @@ def forward( |
| layer_head_mask: Optional[torch.Tensor] = None, |
| output_attentions: bool = False, |
| position_ids: Optional[torch.Tensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| """Input shape: Batch x Time x Channel""" |
| - bsz, _, _ = hidden_states.size() |
| |
| - # get query proj |
| - query_states = self.q_proj(hidden_states) |
| - # get key, value proj |
| - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
| - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
| - if past_key_value is not None: |
| - # reuse k, v, self_attention |
| - key_states = torch.cat([past_key_value[0], key_states], dim=2) |
| - value_states = torch.cat([past_key_value[1], value_states], dim=2) |
| + bsz, query_length, _ = hidden_states.size() |
| |
| - past_key_value = (key_states, value_states) |
| + query_states = self.q_proj(hidden_states) |
| + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) |
| |
| - query_length = query_states.shape[1] |
| - tgt_len = key_states.shape[-2] |
| + key_states = self.k_proj(hidden_states) |
| + value_states = self.v_proj(hidden_states) |
| + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| - # Flash attention requires the input to have the shape |
| - # batch_size x seq_length x head_dim x hidden_dim |
| - query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) |
| - key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) |
| - value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) |
| + if past_key_value is not None: |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + key_states, value_states = past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| |
| attn_dropout = self.dropout if self.training else 0.0 |
| |
| + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache |
| + # to be able to avoid many of these transpose/reshape/view. |
| + key_states = key_states.transpose(1, 2) |
| + value_states = value_states.transpose(1, 2) |
| + |
| # In PEFT, usually we cast the layer norms in float32 for training stability reasons |
| # therefore the input hidden states gets silently casted in float32. Hence, we need |
| # cast them back in float16 just to be sure everything works as expected. |
| @@ -331,6 +297,7 @@ def forward( |
| layer_head_mask: Optional[torch.Tensor] = None, |
| output_attentions: bool = False, |
| position_ids: Optional[torch.Tensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
| if output_attentions or layer_head_mask is not None: |
| logger.warning_once( |
| @@ -344,24 +311,24 @@ def forward( |
| layer_head_mask=layer_head_mask, |
| past_key_value=past_key_value, |
| output_attentions=output_attentions, |
| - ) # TODO after merge add position_ids=position_ids |
| + cache_position=cache_position, |
| + ) |
| |
| bsz, q_len, _ = hidden_states.size() |
| |
| - query_states = self.q_proj(hidden_states) * self.scaling |
| - query_states = self._shape(query_states, -1, bsz) |
| - |
| - # get key, value proj |
| - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
| - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
| - if past_key_value is not None: |
| - # reuse k, v, self_attention |
| - key_states = torch.cat([past_key_value[0], key_states], dim=2) |
| - value_states = torch.cat([past_key_value[1], value_states], dim=2) |
| + query_states = self.q_proj(hidden_states) |
| + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| - past_key_value = (key_states, value_states) |
| + key_states = self.k_proj(hidden_states) |
| + value_states = self.v_proj(hidden_states) |
| + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) |
| |
| - # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous |
| + if past_key_value is not None: |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + key_states, value_states = past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| |
| causal_mask = attention_mask |
| if attention_mask is not None: |
| @@ -378,10 +345,6 @@ def forward( |
| attn_mask=causal_mask, |
| dropout_p=self.dropout if self.training else 0.0, |
| is_causal=is_causal, |
| - # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T |
| - # so we need to scale to remove scaling in SDPA to have similar results with eager. |
| - # Maybe needs a change in the model to remove scaling in query projection |
| - scale=1.0, |
| ) |
| |
| attn_output = attn_output.transpose(1, 2).contiguous() |
| @@ -399,11 +362,11 @@ def forward( |
| |
| |
| class OPTDecoderLayer(nn.Module): |
| - def __init__(self, config: OPTConfig): |
| + def __init__(self, config: OPTConfig, layer_idx: int = None): |
| super().__init__() |
| self.embed_dim = config.hidden_size |
| |
| - self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config) |
| + self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) |
| |
| self.do_layer_norm_before = config.do_layer_norm_before |
| self.dropout = config.dropout |
| @@ -425,6 +388,7 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| position_ids: Optional[torch.LongTensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| """ |
| Args: |
| @@ -440,6 +404,8 @@ def forward( |
| If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding |
| (see `past_key_values`). |
| past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence.. |
| """ |
| |
| residual = hidden_states |
| @@ -456,6 +422,7 @@ def forward( |
| attention_mask=attention_mask, |
| layer_head_mask=layer_head_mask, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) |
| hidden_states = residual + hidden_states |
| @@ -524,6 +491,9 @@ class OPTPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["OPTDecoderLayer"] |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_cache_class = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| @@ -601,6 +571,10 @@ def _init_weights(self, module): |
| config.n_positions - 1]`. for padding use -1. |
| |
| [What are position IDs?](../glossary#position-ids) |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, |
| + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer |
| + the complete sequence length. |
| """ |
| |
| |
| @@ -643,9 +617,7 @@ def __init__(self, config: OPTConfig): |
| else: |
| self.final_layer_norm = None |
| |
| - self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) |
| - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" |
| - self._use_sdpa = config._attn_implementation == "sdpa" |
| + self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) |
| |
| self.gradient_checkpointing = False |
| # Initialize weights and apply final processing |
| @@ -657,48 +629,130 @@ def get_input_embeddings(self): |
| def set_input_embeddings(self, value): |
| self.embed_tokens = value |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| def _update_causal_mask( |
| self, |
| - inputs_embeds: torch.Tensor, |
| - input_shape: Tuple[int, int], |
| - past_key_values_length: int, |
| - attention_mask: Optional[torch.Tensor] = None, |
| - head_mask: Optional[torch.Tensor] = None, |
| - output_attentions: Optional[bool] = None, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| ): |
| - """ |
| - Updates the causal mask for the decoder. |
| - """ |
| - batch_size, seq_length = input_shape |
| - mask_seq_length = past_key_values_length + seq_length |
| - if self._use_flash_attention_2: |
| - # 2d mask is passed through the layers |
| - causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None |
| - attention_mask = ( |
| - torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - if attention_mask is None |
| - else attention_mask |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and (attention_mask == 0.0).any(): |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| ) |
| |
| - return causal_attention_mask, attention_mask |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - elif attention_mask.shape[1] != mask_seq_length: |
| - raise ValueError( |
| - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " |
| - f"{mask_seq_length} (sum of the lengths of current and past inputs)" |
| - ) |
| - if self._use_sdpa and not output_attentions and head_mask is None: |
| - causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( |
| - attention_mask, input_shape, inputs_embeds, past_key_values_length |
| - ) |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type in ["cuda", "xpu"] |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| else: |
| - causal_attention_mask = _prepare_4d_causal_attention_mask( |
| - attention_mask, input_shape, inputs_embeds, past_key_values_length |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( |
| + causal_mask.device |
| + ) |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| |
| - return causal_attention_mask, attention_mask |
| + return causal_mask |
| |
| def forward( |
| self, |
| @@ -712,6 +766,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| r""" |
| Args: |
| @@ -764,6 +819,10 @@ def forward( |
| config.n_positions - 1]`. for padding use -1. |
| |
| [What are position IDs?](../glossary#position-ids) |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, |
| + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer |
| + the complete sequence length. |
| """ |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| @@ -773,51 +832,65 @@ def forward( |
| |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
| |
| - # retrieve input_ids and inputs_embeds |
| - if input_ids is not None and inputs_embeds is not None: |
| - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") |
| - elif input_ids is not None: |
| - input_shape = input_ids.size() |
| - input_ids = input_ids.view(-1, input_shape[-1]) |
| - elif inputs_embeds is not None: |
| - input_shape = inputs_embeds.size()[:-1] |
| - else: |
| - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") |
| + if (input_ids is None) ^ (inputs_embeds is not None): |
| + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
| + |
| + if self.gradient_checkpointing and self.training and use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
| + ) |
| + use_cache = False |
| + |
| + if input_ids is not None: |
| + input_ids = input_ids.view(-1, input_ids.shape[-1]) |
| |
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
| |
| - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 |
| + return_legacy_cache = False |
| + if use_cache and not isinstance(past_key_values, Cache): |
| + return_legacy_cache = True |
| + past_key_values = DynamicCache.from_legacy_cache(past_key_values) |
| + if past_key_values is None: |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " |
| + "You should pass an instance of `DynamicCache` instead, e.g. " |
| + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
| + ) |
| + |
| + if attention_mask is None: |
| + seq_length = past_seen_tokens + inputs_embeds.shape[1] |
| + attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device) |
| |
| - causal_attention_mask, attention_mask = self._update_causal_mask( |
| - inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| ) |
| - # embed positions |
| |
| + # embed positions |
| if position_ids is None: |
| + # position_ids = cache_position.unsqueeze(0) |
| position_ids = torch.cumsum(attention_mask, dim=1) |
| position_ids = (position_ids * attention_mask - 1).long() |
| - # cut positions if `past_key_values_length` is > 0 |
| - position_ids = position_ids[:, past_key_values_length:] |
| + # cut positions if `past_seen_tokens` is > 0 |
| + position_ids = position_ids[:, past_seen_tokens:] |
| |
| - pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) |
| + pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids) |
| |
| if self.project_in is not None: |
| inputs_embeds = self.project_in(inputs_embeds) |
| |
| hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device) |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # decoder layers |
| all_hidden_states = () if output_hidden_states else None |
| all_self_attns = () if output_attentions else None |
| - next_decoder_cache = () if use_cache else None |
| + next_decoder_cache = None |
| |
| # check if head_mask has a correct number of layers specified if desired |
| for attn_mask, mask_name in zip([head_mask], ["head_mask"]): |
| @@ -838,34 +911,34 @@ def forward( |
| if dropout_probability < self.layerdrop: |
| continue |
| |
| - past_key_value = past_key_values[idx] if past_key_values is not None else None |
| - |
| if self.gradient_checkpointing and self.training: |
| layer_outputs = self._gradient_checkpointing_func( |
| decoder_layer.__call__, |
| hidden_states, |
| - causal_attention_mask, |
| + causal_mask, |
| head_mask[idx] if head_mask is not None else None, |
| None, |
| output_attentions, |
| use_cache, |
| position_ids, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| hidden_states, |
| - attention_mask=causal_attention_mask, |
| + attention_mask=causal_mask, |
| position_ids=position_ids, |
| layer_head_mask=(head_mask[idx] if head_mask is not None else None), |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| + cache_position=cache_position, |
| ) |
| |
| hidden_states = layer_outputs[0] |
| |
| if use_cache: |
| - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
| + next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
| |
| if output_attentions: |
| all_self_attns += (layer_outputs[1],) |
| @@ -881,6 +954,9 @@ def forward( |
| all_hidden_states += (hidden_states,) |
| |
| next_cache = next_decoder_cache if use_cache else None |
| + if return_legacy_cache: |
| + next_cache = next_cache.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
| return BaseModelOutputWithPast( |
| @@ -930,6 +1006,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ) -> Union[Tuple, BaseModelOutputWithPast]: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| @@ -950,6 +1027,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1008,6 +1086,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| r""" |
| @@ -1069,6 +1148,10 @@ def forward( |
| config.n_positions - 1]`. for padding use -1. |
| |
| [What are position IDs?](../glossary#position-ids) |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, |
| + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer |
| + the complete sequence length. |
| |
| Returns: |
| |
| @@ -1107,6 +1190,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| logits = self.lm_head(outputs[0]).contiguous() |
| |
| |
| |
| |
| @@ -29,6 +29,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_flash_attn_2_available, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -508,7 +509,7 @@ def forward( |
| |
| special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| - if inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index) |
| raise ValueError( |
| f"Number of images does not match number of special image tokens in the input text. " |
| |
| |
| |
| |
| @@ -38,8 +38,6 @@ class VideoLlavaConfig(PretrainedConfig): |
| text_config (`Union[AutoConfig, dict]`, *optional*): |
| The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. |
| Defaults to `LlamaConfig` if not indicated. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32000): |
| The image token index to encode the image prompt. |
| video_token_index (`int`, *optional*, defaults to 32001): |
| @@ -88,7 +86,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32000, |
| video_token_index=32001, |
| projector_hidden_act="gelu", |
| @@ -99,7 +96,6 @@ def __init__( |
| multimodal_projector_bias=True, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.video_token_index = video_token_index |
| self.projector_hidden_act = projector_hidden_act |
| |
| |
| |
| |
| @@ -28,6 +28,7 @@ |
| from ...utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -137,6 +138,8 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| std = ( |
| @@ -276,92 +279,6 @@ def set_decoder(self, decoder): |
| def get_decoder(self): |
| return self.language_model.get_decoder() |
| |
| - def _merge_input_ids_with_visual_features( |
| - self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1 |
| - ): |
| - num_images, num_image_patches, embed_dim = visual_features.shape |
| - batch_size, sequence_length = input_ids.shape |
| - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) |
| - special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index |
| - |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == special_vision_token |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # Compute the maximum embed dimension |
| - max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length |
| - batch_indices, non_image_indices = torch.where(input_ids != special_vision_token) |
| - |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - new_token_positions = ( |
| - torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1 |
| - ) |
| - nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1] |
| - if left_padding: |
| - new_token_positions += nb_image_pad[:, None] # offset for left padding |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - # expand input ids so that the second "merge" with videos does not fail |
| - final_embedding = torch.zeros( |
| - batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - final_input_ids = torch.full( |
| - (batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] |
| - if labels is not None: |
| - final_labels = torch.full( |
| - (batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device |
| - ) |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - else: |
| - final_labels = None |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling |
| - image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - if left_padding: |
| - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) |
| - else: |
| - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 |
| - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) |
| - image_to_overwrite &= padding_mask |
| - |
| - if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): |
| - visual_type = "videos" if num_frames == 8 else "images" |
| - num_images //= num_frames |
| - raise ValueError( |
| - f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation." |
| - ) |
| - |
| - final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids |
| - |
| def get_image_features( |
| self, |
| pixel_values_images: torch.FloatTensor, |
| @@ -579,14 +496,14 @@ def forward( |
| vision_feature_layer=vision_feature_layer, |
| vision_feature_select_strategy=vision_feature_select_strategy, |
| ) |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] * image_features.shape[1] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] * image_features.shape[1] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| @@ -595,14 +512,14 @@ def forward( |
| pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer |
| ) |
| |
| - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() |
| - n_video_features = video_features.shape[0] * video_features.shape[1] |
| - if n_video_tokens != n_video_features: |
| + special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): |
| + n_video_tokens = (input_ids == self.config.video_token_index).sum() |
| + n_video_features = video_features.shape[0] * video_features.shape[1] |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) |
| |
| |
| |
| |
| |
| @@ -37,8 +37,6 @@ class VipLlavaConfig(PretrainedConfig): |
| Custom vision config or dict |
| text_config (`Union[AutoConfig, dict]`, *optional*): |
| The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`. |
| - ignore_index (`int`, *optional*, defaults to -100): |
| - The ignore index for the loss function. |
| image_token_index (`int`, *optional*, defaults to 32000): |
| The image token index to encode the image prompt. |
| projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): |
| @@ -78,7 +76,6 @@ def __init__( |
| self, |
| vision_config=None, |
| text_config=None, |
| - ignore_index=-100, |
| image_token_index=32000, |
| projector_hidden_act="gelu", |
| projector_layernorm_eps=1e-5, |
| @@ -86,7 +83,6 @@ def __init__( |
| image_seq_length=576, |
| **kwargs, |
| ): |
| - self.ignore_index = ignore_index |
| self.image_token_index = image_token_index |
| self.projector_hidden_act = projector_hidden_act |
| self.projector_layernorm_eps = projector_layernorm_eps |
| |
| |
| |
| |
| @@ -28,6 +28,7 @@ |
| from ...utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -137,6 +138,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel): |
| _supports_cache_class = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| + _supports_quantized_cache = True |
| + _supports_static_cache = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of VipLlava isn't meant for training from scratch - only |
| @@ -297,89 +300,6 @@ def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_lay |
| image_features = self.multi_modal_projector(image_features) |
| return image_features |
| |
| - def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): |
| - num_images, num_image_patches, embed_dim = image_features.shape |
| - batch_size, sequence_length = input_ids.shape |
| - left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) |
| - # 1. Create a mask to know where special image tokens are |
| - special_image_token_mask = input_ids == self.config.image_token_index |
| - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| - # Compute the maximum embed dimension |
| - max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length |
| - batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) |
| - |
| - # 2. Compute the positions where text should be written |
| - # Calculate new positions for text tokens in merged image-text sequence. |
| - # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. |
| - # `torch.cumsum` computes how each image token shifts subsequent text token positions. |
| - # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. |
| - new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 |
| - nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] |
| - if left_padding: |
| - new_token_positions += nb_image_pad[:, None] # offset for left padding |
| - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| - |
| - # 3. Create the full embedding, already padded to the maximum position |
| - final_embedding = torch.zeros( |
| - batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device |
| - ) |
| - final_attention_mask = torch.zeros( |
| - batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device |
| - ) |
| - if labels is not None: |
| - final_labels = torch.full( |
| - (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device |
| - ) |
| - # In case the Vision model or the Language model has been offloaded to CPU, we need to manually |
| - # set the corresponding tensors into their correct target device. |
| - target_device = inputs_embeds.device |
| - batch_indices, non_image_indices, text_to_overwrite = ( |
| - batch_indices.to(target_device), |
| - non_image_indices.to(target_device), |
| - text_to_overwrite.to(target_device), |
| - ) |
| - attention_mask = attention_mask.to(target_device) |
| - |
| - # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] |
| - # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features |
| - final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] |
| - final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] |
| - if labels is not None: |
| - final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] |
| - |
| - # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) |
| - image_to_overwrite = torch.full( |
| - (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device |
| - ) |
| - image_to_overwrite[batch_indices, text_to_overwrite] = False |
| - if left_padding: |
| - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) |
| - else: |
| - mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 |
| - padding_mask = mask <= new_token_positions[:, -1:].to(target_device) |
| - image_to_overwrite &= padding_mask |
| - |
| - if image_to_overwrite.sum() != image_features.shape[:-1].numel(): |
| - raise ValueError( |
| - f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" |
| - f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." |
| - ) |
| - |
| - final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) |
| - final_attention_mask |= image_to_overwrite |
| - position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) |
| - |
| - # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. |
| - batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) |
| - indices_to_mask = new_token_positions[batch_indices, pad_indices] |
| - |
| - final_embedding[batch_indices, indices_to_mask] = 0 |
| - |
| - if labels is None: |
| - final_labels = None |
| - |
| - return final_embedding, final_attention_mask, final_labels, position_ids |
| - |
| @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") |
| @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) |
| @@ -469,14 +389,14 @@ def forward( |
| pixel_values=pixel_values, vision_feature_layers=vision_feature_layers |
| ) |
| |
| - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() |
| - n_image_features = image_features.shape[0] * image_features.shape[1] |
| - if n_image_tokens != n_image_features: |
| + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| + if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
| + n_image_tokens = (input_ids == self.config.image_token_index).sum() |
| + n_image_features = image_features.shape[0] * image_features.shape[1] |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) |
| - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
| image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) |
| |
| |
| |
| |
| |
| @@ -1783,12 +1783,12 @@ def test_generate_from_inputs_embeds_with_static_cache(self): |
| model.config.use_cache = True |
| model.config.is_decoder = True |
| batch_size = input_ids.shape[0] |
| - max_length = 30 |
| + max_new_tokens = 10 |
| |
| # here we force to not stop at eos and go until max-length |
| model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1 |
| generation_kwargs = { |
| - "max_length": max_length, |
| + "max_new_tokens": max_new_tokens, |
| "cache_implementation": "static", |
| "return_dict_in_generate": True, # Required to return `past_key_values` |
| } |
| @@ -1811,10 +1811,11 @@ def test_generate_from_inputs_embeds_with_static_cache(self): |
| |
| # we should get `max_length - 1` in shape, not `max_length - embeds_length`. |
| # -1 because the last generated token isn't yet in the cache. |
| - cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim) |
| - self.assertTrue(isinstance(outputs.past_key_values, StaticCache)) |
| - self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers) |
| - self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape) |
| + max_length = max_new_tokens + inputs_embeds.shape[1] - 1 |
| + cache_shape = [batch_size, num_key_value_heads, max_length, head_dim] |
| + self.assertIsInstance(outputs.past_key_values, StaticCache) |
| + self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers) |
| + self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape) |
| |
| @pytest.mark.generate |
| def test_generate_continue_from_past_key_values(self): |
| @@ -2022,7 +2023,7 @@ def test_generate_with_static_cache(self): |
| |
| config.is_decoder = True |
| batch_size = main_input.shape[0] |
| - seq_length = main_input.shape[-1] |
| + seq_length = self.model_tester.seq_length |
| max_new_tokens = 20 |
| |
| for dtype in (torch.float32, torch.float16): |
| @@ -2134,7 +2135,15 @@ def test_generate_compile_model_forward(self): |
| # compilation-specific setup |
| torch.compiler.reset() # prevent cached compilation from being used in the test |
| has_defined_cache_implementation = model.generation_config.cache_implementation is not None |
| - model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) |
| + |
| + # BLIP is the only exception with custom generate which call `self.lm.generate()` |
| + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` |
| + # compatible with multimodality |
| + if "blip" in model.__class__.__name__.lower(): |
| + model.language_model.generation_config.compile_config._compile_all_devices = True |
| + else: |
| + # force compilation (e.g. fast CI, CPU |
| + model.generation_config.compile_config._compile_all_devices = True |
| |
| generation_kwargs = { |
| "do_sample": False, |
| @@ -2175,7 +2184,14 @@ def test_generate_compile_model_forward(self): |
| ) |
| self.assertFalse(isinstance(decoder_cache, DynamicCache)) |
| self.assertTrue(decoder_cache.is_compileable) |
| - self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called |
| + |
| + # BLIP is the only exception with custom generate which call `self.lm.generate()` |
| + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` |
| + # compatible with multimodality |
| + if "blip" in model.__class__.__name__.lower(): |
| + self.assertTrue(hasattr(model.language_model, "_compiled_call")) |
| + else: |
| + self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called |
| |
| for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): |
| self._check_similar_generate_outputs(dynamic_result, compiled_result) |
| @@ -2198,9 +2214,19 @@ def test_generate_compilation_all_outputs(self): |
| # compilation-specific setup |
| torch.compiler.reset() # prevent cached compilation from being used in the test |
| has_defined_cache_implementation = model.generation_config.cache_implementation is not None |
| - model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) |
| - if not has_defined_cache_implementation: |
| - model.generation_config.cache_implementation = "static" |
| + |
| + # BLIP is the only exception with custom generate which call `self.lm.generate()` |
| + # We should avoid such calls in all subsequent multimodal models and try to make `generate()` |
| + # compatible with multimodality |
| + if "blip" in model.__class__.__name__.lower(): |
| + model.language_model.generation_config.compile_config._compile_all_devices = True |
| + if not has_defined_cache_implementation: |
| + model.language_model.generation_config.cache_implementation = "static" |
| + else: |
| + # force compilation (e.g. fast CI, CPU) |
| + model.generation_config.compile_config._compile_all_devices = True |
| + if not has_defined_cache_implementation: |
| + model.generation_config.cache_implementation = "static" |
| |
| logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config) |
| output_generate = model.generate( |
| @@ -2218,8 +2244,10 @@ def test_generate_compilation_all_outputs(self): |
| **inputs_dict, |
| ) |
| |
| - # Sanity check: compilation has happened |
| - self.assertTrue(hasattr(model, "_compiled_call")) |
| + if "blip" in model.__class__.__name__.lower(): |
| + self.assertTrue(hasattr(model.language_model, "_compiled_call")) |
| + else: |
| + self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called |
| |
| if model.config.is_encoder_decoder: |
| self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) |
| |
| |
| |
| |
| @@ -286,10 +286,18 @@ def test_generate_from_inputs_embeds_0_greedy(self): |
| def test_generate_from_inputs_embeds_1_beam_search(self): |
| pass |
| |
| - @unittest.skip(reason="Unsupported") |
| + @unittest.skip(reason="Dynamic control flow due to MoE") |
| def test_generate_with_static_cache(self): |
| pass |
| |
| + @unittest.skip(reason="Dynamic control flow due to MoE") |
| + def test_generate_from_inputs_embeds_with_static_cache(self): |
| + pass |
| + |
| + @unittest.skip(reason="Dynamic control flow due to MoE") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| |
| @require_torch |
| class AriaForConditionalGenerationIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -816,6 +816,10 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): |
| def test_generate_from_inputs_embeds(self, _, num_beams): |
| pass |
| |
| + @unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present") |
| + def test_generate_from_inputs_embeds_with_static_cache(self): |
| + pass |
| + |
| |
| # this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py |
| class Blip2TextModelTester: |
| |
| |
| |
| |
| @@ -386,10 +386,6 @@ def test_disk_offload_bin(self): |
| def test_cpu_offload(self): |
| pass |
| |
| - @unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme |
| - def test_custom_4d_attention_mask(self): |
| - pass |
| - |
| @unittest.skip("VQ-VAE module doesn't initialize weights properly") |
| def test_initialization(self): |
| pass |
| |
| |
| |
| |
| @@ -256,12 +256,6 @@ def test_generate_from_inputs_embeds_with_static_cache(self): |
| def test_past_key_values_format(self): |
| pass |
| |
| - @unittest.skip( |
| - reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step" |
| - ) |
| - def test_generate_compile_1_end_to_end(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| |
| |
| |
| |
| @@ -838,6 +838,14 @@ def test_contrastive_generate_low_memory(self): |
| def test_custom_4d_attention_mask(self): |
| pass |
| |
| + @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") |
| + def test_generate_with_static_cache(self): |
| + pass |
| + |
| + @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| @unittest.skip(reason="We only test the model that takes in multiple images") |
| def test_model(self): |
| pass |
| |
| |
| |
| |
| @@ -530,6 +530,12 @@ def test_save_load_fast_init_from_base(self): |
| def test_save_load_fast_init_to_base(self): |
| pass |
| |
| + @unittest.skip( |
| + "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present" |
| + ) |
| + def test_generate_from_inputs_embeds_with_static_cache(self): |
| + pass |
| + |
| def test_forward_signature(self): |
| config, _ = self.model_tester.prepare_config_and_inputs_for_common() |
| |
| |
| |
| |
| |
| @@ -546,6 +546,12 @@ def test_save_load_fast_init_from_base(self): |
| def test_save_load_fast_init_to_base(self): |
| pass |
| |
| + @unittest.skip( |
| + "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present" |
| + ) |
| + def test_generate_from_inputs_embeds_with_static_cache(self): |
| + pass |
| + |
| def test_forward_signature(self): |
| config, _ = self.model_tester.prepare_config_and_inputs_for_common() |
| |
| |
| |
| |
| |
| @@ -316,14 +316,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): |
| def test_training_gradient_checkpointing_use_reentrant_false(self): |
| pass |
| |
| - @unittest.skip(reason="Compile not yet supported because in LLava models") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| - @unittest.skip(reason="Compile not yet supported because in LLava models") |
| - def test_sdpa_can_dispatch_on_flash(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| |
| |
| |
| |
| @@ -365,22 +365,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): |
| def test_training_gradient_checkpointing_use_reentrant_false(self): |
| pass |
| |
| - @unittest.skip(reason="Feedforward chunking is not yet supported") |
| - def test_feed_forward_chunking(self): |
| - pass |
| - |
| - @unittest.skip(reason="CPU offload is not yet supported") |
| - def test_cpu_offload(self): |
| - pass |
| - |
| - @unittest.skip(reason="Compile not yet supported because in LLava models") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| - @unittest.skip(reason="Compile not yet supported because in LLava models") |
| - def test_sdpa_can_dispatch_on_flash(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| @@ -391,6 +375,10 @@ def test_flash_attn_2_fp32_ln(self): |
| def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| pass |
| |
| + @unittest.skip("LLaVA Next has dynamic control flow in unpadding") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| |
| @require_torch |
| class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -382,26 +382,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): |
| def test_training_gradient_checkpointing_use_reentrant_false(self): |
| pass |
| |
| - @unittest.skip(reason="Feedforward chunking is not yet supported") |
| - def test_feed_forward_chunking(self): |
| - pass |
| - |
| - @unittest.skip(reason="CPU offload is not yet supported") |
| - def test_cpu_offload(self): |
| - pass |
| - |
| - @unittest.skip( |
| - reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)" |
| - ) |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| - @unittest.skip( |
| - reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)" |
| - ) |
| - def test_sdpa_can_dispatch_on_flash(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| @@ -412,6 +392,10 @@ def test_flash_attn_2_fp32_ln(self): |
| def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| pass |
| |
| + @unittest.skip("LLaVA Next Video has dynamic control flow in unpadding") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| |
| @require_torch |
| class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -346,6 +346,10 @@ def test_flash_attn_2_fp32_ln(self): |
| def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): |
| pass |
| |
| + @unittest.skip("LLaVA OneVision has dynamic control flow in unpadding") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| |
| @require_torch |
| class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -540,7 +540,6 @@ def prepare_config_and_inputs_for_common(self): |
| "attention_mask": attention_mask, |
| "decoder_input_ids": decoder_input_ids, |
| "decoder_attention_mask": decoder_attention_mask, |
| - "use_cache": False, |
| } |
| return config, inputs_dict |
| |
| |
| |
| |
| |
| @@ -81,7 +81,7 @@ def __init__( |
| hidden_act="gelu", |
| hidden_dropout_prob=0.1, |
| attention_probs_dropout_prob=0.1, |
| - max_position_embeddings=20, |
| + max_position_embeddings=50, |
| eos_token_id=2, |
| pad_token_id=1, |
| bos_token_id=0, |
| @@ -89,7 +89,6 @@ def __init__( |
| num_labels=3, |
| word_embed_proj_dim=16, |
| type_sequence_label_size=2, |
| - attn_implementation="eager", |
| ): |
| self.parent = parent |
| self.batch_size = batch_size |
| @@ -113,7 +112,6 @@ def __init__( |
| self.type_sequence_label_size = type_sequence_label_size |
| self.word_embed_proj_dim = word_embed_proj_dim |
| self.is_encoder_decoder = False |
| - self.attn_implementation = attn_implementation |
| |
| def prepare_config_and_inputs(self): |
| input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp( |
| @@ -143,7 +141,6 @@ def get_config(self): |
| embed_dim=self.embed_dim, |
| is_encoder_decoder=False, |
| word_embed_proj_dim=self.word_embed_proj_dim, |
| - attn_implementation=self.attn_implementation, |
| ) |
| |
| def get_pipeline_config(self): |
| |
| |
| |
| |
| @@ -545,7 +545,6 @@ def prepare_config_and_inputs_for_common(self): |
| "attention_mask": attention_mask, |
| "decoder_input_ids": decoder_input_ids, |
| "decoder_attention_mask": decoder_attention_mask, |
| - "use_cache": False, |
| } |
| return config, inputs_dict |
| |
| |
| |
| |
| |
| @@ -226,14 +226,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): |
| def test_training_gradient_checkpointing_use_reentrant_false(self): |
| pass |
| |
| - @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| - @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`") |
| - def test_sdpa_can_dispatch_on_flash(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| |
| |
| |
| |
| @@ -306,14 +306,6 @@ def test_training_gradient_checkpointing_use_reentrant(self): |
| def test_training_gradient_checkpointing_use_reentrant_false(self): |
| pass |
| |
| - @unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava") |
| - def test_sdpa_can_compile_dynamic(self): |
| - pass |
| - |
| - @unittest.skip(reason="Compile not yet supported because in LLava models") |
| - def test_sdpa_can_dispatch_on_flash(self): |
| - pass |
| - |
| @unittest.skip("FlashAttention only support fp16 and bf16 data type") |
| def test_flash_attn_2_fp32_ln(self): |
| pass |
| |
| |
| |
| |
| @@ -4324,10 +4324,6 @@ def test_sdpa_can_dispatch_on_flash(self): |
| |
| config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| inputs_dict = self._prepare_for_class(inputs_dict, model_class) |
| - if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]: |
| - self.skipTest( |
| - reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input" |
| - ) |
| if config.model_type in ["paligemma"]: |
| self.skipTest( |
| "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" |
| @@ -4778,6 +4774,9 @@ def test_custom_4d_attention_mask(self): |
| model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| set_model_for_less_flaky_test(model) |
| |
| + if "position_ids" not in inspect.signature(model.forward).parameters: |
| + continue # this model doesn't accept position ids as input |
| + |
| ( |
| input_ids, |
| position_ids, |
|
|