harness / diffs /35724.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 84f0356cecb2..916631da7e8f 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -2016,6 +2016,9 @@ def forward(
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config_class = Blip2Config
main_input_name = "pixel_values"
+ _supports_cache_class = True
+ _supports_static_cache = True
+ _supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: Blip2Config):
super().__init__(config)
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 0a9421409e25..65322e236ca0 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1284,13 +1284,13 @@ def forward(
if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
- n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
- n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
- if n_image_tokens_in_text != n_image_features:
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
+ n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
+ n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
- special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index 11353a0a990c..75144c65ecff 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -25,7 +25,7 @@
import torch.nn as nn
from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -701,7 +701,7 @@ def _update_causal_mask(
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
- if isinstance(past_key_values, HybridCache):
+ if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index e9fd43c49000..c977f873dc8c 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -25,7 +25,7 @@
import torch.nn as nn
from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -713,7 +713,7 @@ def _update_causal_mask(
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
- if isinstance(past_key_values, HybridCache):
+ if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 4e3c8487c4d8..805e6ba0d2a3 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -20,7 +20,7 @@
import torch.utils.checkpoint
from ...activations import ACT2FN
-from ...cache_utils import Cache, HybridCache
+from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -550,7 +550,7 @@ def _update_causal_mask(
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
- if isinstance(past_key_values, HybridCache):
+ if isinstance(past_key_values, (HybridCache, StaticCache)):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
diff --git a/src/transformers/models/got_ocr2/configuration_got_ocr2.py b/src/transformers/models/got_ocr2/configuration_got_ocr2.py
index 480252ab1471..fb9a1fb68889 100644
--- a/src/transformers/models/got_ocr2/configuration_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/configuration_got_ocr2.py
@@ -132,8 +132,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@@ -161,13 +159,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index 957e05bea75a..86598ac08965 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -594,6 +594,8 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of GotOcr2 isn't meant for training from scratch - only
@@ -748,89 +750,6 @@ def get_image_features(
image_outputs = self.vision_tower(pixel_values).last_hidden_state
return self.multi_modal_projector(image_outputs)
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
- num_images, num_image_patches, embed_dim = image_features.shape
- batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == self.config.image_token_index
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # Compute the maximum embed dimension
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
- batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
- if left_padding:
- new_token_positions += nb_image_pad[:, None] # offset for left padding
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- if labels is not None:
- final_labels = torch.full(
- (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- if labels is not None:
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- if left_padding:
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
- else:
- mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
- padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
- image_to_overwrite &= padding_mask
-
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
- raise ValueError(
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
- )
-
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
- final_embedding[batch_indices, indices_to_mask] = 0
-
- if labels is None:
- final_labels = None
-
- return final_embedding, final_attention_mask, final_labels, position_ids
-
@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GotOcr2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py
index 899075683eb4..fff434ead2e9 100644
--- a/src/transformers/models/got_ocr2/modular_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py
@@ -170,8 +170,6 @@ class GotOcr2Config(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 151859):
The image token index to encode the image prompt.
image_seq_length (`int`, *optional*, defaults to 576):
@@ -199,13 +197,11 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=151859,
image_seq_length=576,
pad_token_id=-1,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.image_seq_length = image_seq_length
self.pad_token_id = pad_token_id
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index d5153fb3f828..10b6efbc5943 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -51,7 +51,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
- _supports_static_cache = False # TODO (fix me): compilation fails due to a stide error?
+ _supports_static_cache = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -129,8 +129,8 @@ def forward(
cos, sin = position_embeddings
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
- query = torch.cat((query, query_pass), dim=-1)
- key = torch.cat((key, key_pass), dim=-1)
+ query = torch.cat((query, query_pass), dim=-1).contiguous()
+ key = torch.cat((key, key_pass), dim=-1).contiguous()
# Cache QKV values
if layer_past is not None:
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index d877b8323b3b..546e78eac148 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1108,6 +1108,7 @@ def forward(
router_logits=all_router_logits,
)
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -1116,13 +1117,8 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
- # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
- # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-
if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
+ if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
@@ -1143,7 +1139,6 @@ def _update_causal_mask(
return None
dtype, device = input_tensor.dtype, input_tensor.device
- min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@@ -1154,25 +1149,17 @@ def _update_causal_mask(
else past_seen_tokens + sequence_length + 1
)
- if attention_mask is not None and attention_mask.dim() == 4:
- # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
- causal_mask = attention_mask
- else:
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -1182,6 +1169,7 @@ def _update_causal_mask(
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index b705da44eba4..ea42d65b845c 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -1290,6 +1290,9 @@ def forward(
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
config_class = InstructBlipConfig
main_input_name = "pixel_values"
+ _supports_cache_class = True
+ _supports_static_cache = True
+ _supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: InstructBlipConfig):
super().__init__(config)
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index dcf77863a149..5183a3c22faf 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -1284,6 +1284,9 @@ def forward(
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
config_class = InstructBlipVideoConfig
main_input_name = "pixel_values"
+ _supports_cache_class = True
+ _supports_static_cache = True
+ _supports_quantized_cache = False # not all LM bacbones support (e.g. T5)
def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
diff --git a/src/transformers/models/llava/configuration_llava.py b/src/transformers/models/llava/configuration_llava.py
index d2a3e9747b66..f476591b2eb6 100644
--- a/src/transformers/models/llava/configuration_llava.py
+++ b/src/transformers/models/llava/configuration_llava.py
@@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -83,7 +81,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@@ -92,7 +89,6 @@ def __init__(
multimodal_projector_bias=True,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 36f212e76844..610ab417d92b 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -28,6 +28,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
@@ -321,89 +324,6 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
- num_images, num_image_patches, embed_dim = image_features.shape
- batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == self.config.image_token_index
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # Compute the maximum embed dimension
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
- batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
- if left_padding:
- new_token_positions += nb_image_pad[:, None] # offset for left padding
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- if labels is not None:
- final_labels = torch.full(
- (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- if labels is not None:
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- if left_padding:
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
- else:
- mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
- padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
- image_to_overwrite &= padding_mask
-
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
- raise ValueError(
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
- )
-
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
- final_embedding[batch_indices, indices_to_mask] = 0
-
- if labels is None:
- final_labels = None
-
- return final_embedding, final_attention_mask, final_labels, position_ids
-
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -499,14 +419,14 @@ def forward(
image_sizes=image_sizes,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
diff --git a/src/transformers/models/llava_next/configuration_llava_next.py b/src/transformers/models/llava_next/configuration_llava_next.py
index 2610275cedfd..3836dbf71cd2 100644
--- a/src/transformers/models/llava_next/configuration_llava_next.py
+++ b/src/transformers/models/llava_next/configuration_llava_next.py
@@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -88,7 +86,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@@ -99,7 +96,6 @@ def __init__(
multimodal_projector_bias=True,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 06e1cc63940f..3cdf1b348404 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -31,6 +31,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -245,6 +246,8 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNext isn't meant for training from scratch - only
@@ -405,245 +408,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()
- def _merge_input_ids_with_image_features(
- self,
- image_features,
- feature_lens,
- inputs_embeds,
- input_ids,
- attention_mask,
- position_ids=None,
- labels=None,
- image_token_index=None,
- ignore_index=-100,
- ):
- """
- Merge input_ids with with image features into final embeddings
-
- Args:
- image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
- All vision vectors of all images in the batch
- feature_lens (`torch.LongTensor` of shape `(num_images)`):
- The length of visual embeddings of each image as stacked in `image_features`
- inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
- Token embeddings before merging with visual embeddings
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Input_ids of tokens, possibly filled with image token
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Mask to avoid performing attention on padding token indices.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
- labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
- :abels need to be recalculated to support training (if provided)
- image_token_index (`int`, *optional*)
- Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
- ignore_index (`int`, *optional*)
- Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
- Returns:
- final_embedding, final_attention_mask, position_ids, final_labels
-
- Explanation:
- each image has variable length embeddings, with length specified by feature_lens
- image_features is concatenation of all visual embed vectors
- task: fill each <image> with the correct number of visual embeddings
- Example:
- X (5 patches), Y (3 patches), Z (8)
- X, Y are in the same sequence (in-context learning)
- if right padding
- input_ids: [
- a b c d e f X g h i j k Y l m
- o p q r Z s t u v _ _ _ _ _ _
- ]
- input_ids should be: [
- a b c d e f X X X X X g h i j k Y Y Y l m
- o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
- ]
- labels should be: [
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
- o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
- ]
- elif left padding
- input_ids: [
- a b c d e f X g h i j k Y l m
- _ _ _ _ _ _ o p q r Z s t u v
- ]
- input_ids should be: [
- a b c d e f X X X X X g h i j k Y Y Y l m
- _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
- ]
- labels should be: [
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
- _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
- ]
- Edge cases:
- * If tokens are same but image token sizes are different, then cannot infer left or right padding
- ```python
- cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
- chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
- prompts = [
- "[INST] <image>\nWhat is shown in this image? [/INST]",
- "[INST] <image>\nWhat is shown in this image? [/INST]",
- ]
- inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
- chart_img has 2634 tokens, while cat_img has 2340 tokens
- ```
-
- input_ids: [
- a b c d X g h
- i j Y k l m n
- ]
- where X is 3 tokens while Y is 5, this mean after merge
- if left-padding (batched generation)
- input_ids should be: [
- _ _ a b c d X X X g h
- i j Y Y Y Y Y k l m n
- ]
- elif (right padding) (training)
- input_ids should be: [
- a b c d X X X g h _ _
- i j Y Y Y Y Y k l m n
- ]
- """
- image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
- ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
-
- if self.training and self.padding_side == "left":
- logger.warning_once(
- "Padding side is set to 'left' but the model is in training mode. For training "
- "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
- "If that's intended, ignore this warning"
- )
- if not self.training and self.padding_side == "right":
- logger.warning_once(
- "Padding side is set to 'right' but the model is in inference mode. For correct "
- "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
- "If that's intended, ignore this warning"
- )
-
- with torch.no_grad():
- # ! in llava 1.6, number of patches is variable
- num_images = feature_lens.size(0)
- num_image_features, embed_dim = image_features.shape
- if feature_lens.sum() != num_image_features:
- raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
- batch_size = input_ids.shape[0]
- _left_padding = torch.any(attention_mask[:, 0] == 0)
- _right_padding = torch.any(attention_mask[:, -1] == 0)
-
- left_padding = self.padding_side == "left"
- if batch_size > 1:
- if _left_padding and _right_padding:
- raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
- elif _right_padding and left_padding:
- left_padding = False
- elif _left_padding and not left_padding:
- left_padding = True
-
- # Whether to turn off right padding
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == image_token_index
- # special_image_token_mask: [bsz, seqlen]
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # num_special_image_tokens: [bsz]
- # Reserve for padding of num_images
- total_num_special_image_tokens = torch.sum(special_image_token_mask)
- if total_num_special_image_tokens != num_images:
- raise ValueError(
- f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
- )
- # Compute the maximum embed dimension
- # max_image_feature_lens is max_feature_lens per batch
- feature_lens = feature_lens.to(input_ids.device)
- feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
- feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
- embed_sequence_lengths = (
- (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
- )
- max_embed_dim = embed_sequence_lengths.max()
-
- batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- # ! instead of special_image_token_mask * (num_image_patches - 1)
- # special_image_token_mask * (num_feature_len - 1)
- special_image_token_mask = special_image_token_mask.long()
- special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
- new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
- if left_padding:
- # shift right token positions so that they are ending at the same number
- # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
- new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
-
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- final_input_ids = torch.full(
- (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
- input_ids = input_ids.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
- final_labels = None
- if labels is not None:
- labels = labels.to(target_device)
- final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- with torch.no_grad():
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
- embed_indices = embed_indices.expand(batch_size, max_embed_dim)
- embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
-
- if left_padding:
- # exclude padding on the left
- max_embed_dim = max_embed_dim.to(target_device)
- val = (max_embed_dim - embed_indices) <= embed_seq_lens
- else:
- # exclude padding on the right
- val = embed_indices < embed_seq_lens
- image_to_overwrite &= val
-
- if image_to_overwrite.sum() != num_image_features:
- raise ValueError(
- f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
- f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. "
- f"This prevents correct indexing and breaks batch generation."
- )
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
-
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -875,14 +639,14 @@ def forward(
image_newline=self.image_newline,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
diff --git a/src/transformers/models/llava_next_video/configuration_llava_next_video.py b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
index 6b85ebb4455e..01450f6b587c 100644
--- a/src/transformers/models/llava_next_video/configuration_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/configuration_llava_next_video.py
@@ -38,8 +38,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32001):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -96,7 +94,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
@@ -116,7 +113,6 @@ def __init__(
self.spatial_pool_stride = spatial_pool_stride
self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index f62824947ddf..9ce88c541231 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -32,7 +32,13 @@
from ...image_processing_utils import select_best_resolution
from ...modeling_outputs import ModelOutput
from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_llava_next_video import LlavaNextVideoConfig
@@ -153,6 +159,8 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of LlavaNextVideo isn't meant for training from scratch - only
@@ -440,245 +448,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()
- def _merge_input_ids_with_image_features(
- self,
- image_features,
- feature_lens,
- inputs_embeds,
- input_ids,
- attention_mask,
- position_ids=None,
- labels=None,
- image_token_index=None,
- ignore_index=-100,
- ):
- """
- Merge input_ids with with image features into final embeddings
-
- Args:
- image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
- All vision vectors of all images in the batch
- feature_lens (`torch.LongTensor` of shape `(num_images)`):
- The length of visual embeddings of each image as stacked in `image_features`
- inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
- Token embeddings before merging with visual embeddings
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Input_ids of tokens, possibly filled with image token
- attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Mask to avoid performing attention on padding token indices.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
- labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
- :abels need to be recalculated to support training (if provided)
- image_token_index (`int`, *optional*)
- Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
- ignore_index (`int`, *optional*)
- Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
- Returns:
- final_embedding, final_attention_mask, position_ids, final_labels
-
- Explanation:
- each image has variable length embeddings, with length specified by feature_lens
- image_features is concatenation of all visual embed vectors
- task: fill each <image> with the correct number of visual embeddings
- Example:
- X (5 patches), Y (3 patches), Z (8)
- X, Y are in the same sequence (in-context learning)
- if right padding
- input_ids: [
- a b c d e f X g h i j k Y l m
- o p q r Z s t u v _ _ _ _ _ _
- ]
- input_ids should be: [
- a b c d e f X X X X X g h i j k Y Y Y l m
- o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
- ]
- labels should be: [
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
- o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
- ]
- elif left padding
- input_ids: [
- a b c d e f X g h i j k Y l m
- _ _ _ _ _ _ o p q r Z s t u v
- ]
- input_ids should be: [
- a b c d e f X X X X X g h i j k Y Y Y l m
- _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
- ]
- labels should be: [
- a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
- _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
- ]
- Edge cases:
- * If tokens are same but image token sizes are different, then cannot infer left or right padding
- ```python
- cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
- chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
- prompts = [
- "[INST] <image>\nWhat is shown in this image? [/INST]",
- "[INST] <image>\nWhat is shown in this image? [/INST]",
- ]
- inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
- chart_img has 2634 tokens, while cat_img has 2340 tokens
- ```
-
- input_ids: [
- a b c d X g h
- i j Y k l m n
- ]
- where X is 3 tokens while Y is 5, this mean after merge
- if left-padding (batched generation)
- input_ids should be: [
- _ _ a b c d X X X g h
- i j Y Y Y Y Y k l m n
- ]
- elif (right padding) (training)
- input_ids should be: [
- a b c d X X X g h _ _
- i j Y Y Y Y Y k l m n
- ]
- """
- image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
- ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
-
- if self.training and self.padding_side == "left":
- logger.warning_once(
- "Padding side is set to 'left' but the model is in training mode. For training "
- "it is recommended to set `model.padding_side='right' and `processor.tokenizer.padding_side='right'`. "
- "If that's intended, ignore this warning"
- )
- if not self.training and self.padding_side == "right":
- logger.warning_once(
- "Padding side is set to 'right' but the model is in inference mode. For correct "
- "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. "
- "If that's intended, ignore this warning"
- )
-
- with torch.no_grad():
- # ! in llava 1.6, number of patches is variable
- num_images = feature_lens.size(0)
- num_image_features, embed_dim = image_features.shape
- if feature_lens.sum() != num_image_features:
- raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
- batch_size = input_ids.shape[0]
- _left_padding = torch.any(attention_mask[:, 0] == 0)
- _right_padding = torch.any(attention_mask[:, -1] == 0)
-
- left_padding = self.padding_side == "left"
- if batch_size > 1:
- if _left_padding and _right_padding:
- raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
- elif _right_padding and left_padding:
- left_padding = False
- elif _left_padding and not left_padding:
- left_padding = True
-
- # Whether to turn off right padding
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == image_token_index
- # special_image_token_mask: [bsz, seqlen]
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # num_special_image_tokens: [bsz]
- # Reserve for padding of num_images
- total_num_special_image_tokens = torch.sum(special_image_token_mask)
- if total_num_special_image_tokens != num_images:
- raise ValueError(
- f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
- )
- # Compute the maximum embed dimension
- # max_image_feature_lens is max_feature_lens per batch
- feature_lens = feature_lens.to(input_ids.device)
- feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
- feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
- embed_sequence_lengths = (
- (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
- )
- max_embed_dim = embed_sequence_lengths.max()
-
- batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- # ! instead of special_image_token_mask * (num_image_patches - 1)
- # special_image_token_mask * (num_feature_len - 1)
- special_image_token_mask = special_image_token_mask.long()
- special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
- new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
- if left_padding:
- # shift right token positions so that they are ending at the same number
- # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
- new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
-
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- final_input_ids = torch.full(
- (batch_size, max_embed_dim), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
- input_ids = input_ids.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
- final_labels = None
- if labels is not None:
- labels = labels.to(target_device)
- final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- with torch.no_grad():
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
- embed_indices = embed_indices.expand(batch_size, max_embed_dim)
- embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
-
- if left_padding:
- # exclude padding on the left
- max_embed_dim = max_embed_dim.to(target_device)
- val = (max_embed_dim - embed_indices) <= embed_seq_lens
- else:
- # exclude padding on the right
- val = embed_indices < embed_seq_lens
- image_to_overwrite &= val
-
- if image_to_overwrite.sum() != num_image_features:
- raise ValueError(
- f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
- f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. "
- f"This prevents correct indexing and breaks batch generation."
- )
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids
-
def pack_image_features(self, image_features, image_sizes, vision_feature_select_strategy, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
@@ -948,14 +717,14 @@ def forward(
image_newline=self.image_newline,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -970,14 +739,14 @@ def forward(
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
- n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
- n_video_features = video_features.shape[0]
- if n_video_tokens != n_video_features:
+ special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+ n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
- special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index b2e06c337c1b..8769f8db4131 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -30,6 +30,7 @@
from ...configuration_utils import PretrainedConfig
from ...utils import (
+ is_torchdynamo_compiling,
logging,
)
from ..auto import CONFIG_MAPPING, AutoConfig
@@ -52,8 +53,6 @@ class LlavaNextVideoConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32001):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -110,7 +109,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32001,
projector_hidden_act="gelu",
multimodal_projector_bias=True,
@@ -130,7 +128,6 @@ def __init__(
self.spatial_pool_stride = spatial_pool_stride
self.image_seq_length = image_seq_length
self.video_seq_length = video_seq_length
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.multimodal_projector_bias = multimodal_projector_bias
@@ -479,14 +476,14 @@ def forward(
image_newline=self.image_newline,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -501,14 +498,14 @@ def forward(
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
- n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
- n_video_features = video_features.shape[0]
- if n_video_tokens != n_video_features:
+ special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+ n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
- special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index ed584bda7f5d..e86ce394e13d 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -30,6 +30,7 @@
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_start_docstrings,
+ is_torchdynamo_compiling,
logging,
)
from ...utils.deprecation import deprecate_kwarg
@@ -250,7 +251,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True
- _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
+ _supports_static_cache = True
_supports_quantized_cache = True
_supports_sdpa = True
@@ -712,19 +713,15 @@ def forward(
image_newline=self.image_newline,
vision_aspect_ratio=vision_aspect_ratio,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (
- (input_ids == self.config.image_token_index)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -741,18 +738,14 @@ def forward(
video_features = torch.cat((video_features, image_newline), dim=1)
video_features = video_features.flatten(0, 1)
- n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
- n_video_features = video_features.shape[0]
- if n_video_tokens != n_video_features:
+ special_video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+ special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_index).sum()
+ n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
- special_video_mask = (
- (input_ids == self.config.video_token_index)
- .unsqueeze(-1)
- .expand_as(inputs_embeds)
- .to(inputs_embeds.device)
- )
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 1969acf2f5b1..f1f1ef1821c7 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -22,10 +22,10 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
- _prepare_4d_causal_attention_mask,
- _prepare_4d_causal_attention_mask_for_sdpa,
+ AttentionMaskConverter,
)
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -98,6 +98,7 @@ class OPTAttention(nn.Module):
def __init__(
self,
config: OPTConfig,
+ layer_idx: int = None,
**kwargs,
):
super().__init__()
@@ -106,6 +107,13 @@ def __init__(
self.num_heads = config.num_attention_heads
self.dropout = config.attention_dropout
self.enable_bias = config.enable_bias
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
self.head_dim = self.embed_dim // self.num_heads
self.is_causal = True
@@ -122,9 +130,6 @@ def __init__(
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -134,52 +139,33 @@ def forward(
output_attentions: bool = False,
# isn't needed in normal attention, but needed in flash attention so to keep the signature same
position_ids: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
- past_key_value = (key_states, value_states)
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.view(*proj_shape)
- value_states = value_states.view(*proj_shape)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
+ attn_weights = torch.matmul(query_states, key_states.transpose(3, 2))
if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = torch.max(
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
- )
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
- if attn_weights.dtype == torch.float16:
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
- else:
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
@@ -187,39 +173,19 @@ def forward(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_probs, value_states)
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.transpose(1, 2).contiguous()
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_probs, past_key_value
class OptFlashAttention2(OPTAttention):
@@ -245,33 +211,33 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
- bsz, _, _ = hidden_states.size()
- # get query proj
- query_states = self.q_proj(hidden_states)
- # get key, value proj
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ bsz, query_length, _ = hidden_states.size()
- past_key_value = (key_states, value_states)
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
- query_length = query_states.shape[1]
- tgt_len = key_states.shape[-2]
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
- key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
- value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
attn_dropout = self.dropout if self.training else 0.0
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
@@ -331,6 +297,7 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions or layer_head_mask is not None:
logger.warning_once(
@@ -344,24 +311,24 @@ def forward(
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
- ) # TODO after merge add position_ids=position_ids
+ cache_position=cache_position,
+ )
bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states) * self.scaling
- query_states = self._shape(query_states, -1, bsz)
-
- # get key, value proj
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- past_key_value = (key_states, value_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- # shape now is (bsz, num_heads, seq_len, head_dim), all are continuous
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
causal_mask = attention_mask
if attention_mask is not None:
@@ -378,10 +345,6 @@ def forward(
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
- # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T
- # so we need to scale to remove scaling in SDPA to have similar results with eager.
- # Maybe needs a change in the model to remove scaling in query projection
- scale=1.0,
)
attn_output = attn_output.transpose(1, 2).contiguous()
@@ -399,11 +362,11 @@ def forward(
class OPTDecoderLayer(nn.Module):
- def __init__(self, config: OPTConfig):
+ def __init__(self, config: OPTConfig, layer_idx: int = None):
super().__init__()
self.embed_dim = config.hidden_size
- self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config)
+ self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
@@ -425,6 +388,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -440,6 +404,8 @@ def forward(
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence..
"""
residual = hidden_states
@@ -456,6 +422,7 @@ def forward(
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -524,6 +491,9 @@ class OPTPreTrainedModel(PreTrainedModel):
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
@@ -601,6 +571,10 @@ def _init_weights(self, module):
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
"""
@@ -643,9 +617,7 @@ def __init__(self, config: OPTConfig):
else:
self.final_layer_norm = None
- self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
+ self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -657,48 +629,130 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
- inputs_embeds: torch.Tensor,
- input_shape: Tuple[int, int],
- past_key_values_length: int,
- attention_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = None,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
):
- """
- Updates the causal mask for the decoder.
- """
- batch_size, seq_length = input_shape
- mask_seq_length = past_key_values_length + seq_length
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- attention_mask = (
- torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
- if attention_mask is None
- else attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
)
- return causal_attention_mask, attention_mask
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
- if attention_mask is None:
- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
- elif attention_mask.shape[1] != mask_seq_length:
- raise ValueError(
- f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
- f"{mask_seq_length} (sum of the lengths of current and past inputs)"
- )
- if self._use_sdpa and not output_attentions and head_mask is None:
- causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to plcae the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
else:
- causal_attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
- return causal_attention_mask, attention_mask
+ return causal_mask
def forward(
self,
@@ -712,6 +766,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
@@ -764,6 +819,10 @@ def forward(
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -773,51 +832,65 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if input_ids is not None:
+ input_ids = input_ids.view(-1, input_ids.shape[-1])
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ if past_key_values is None:
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
+ "You should pass an instance of `DynamicCache` instead, e.g. "
+ "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
+ )
+
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if attention_mask is None:
+ seq_length = past_seen_tokens + inputs_embeds.shape[1]
+ attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device)
- causal_attention_mask, attention_mask = self._update_causal_mask(
- inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
- # embed positions
+ # embed positions
if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
- # cut positions if `past_key_values_length` is > 0
- position_ids = position_ids[:, past_key_values_length:]
+ # cut positions if `past_seen_tokens` is > 0
+ position_ids = position_ids[:, past_seen_tokens:]
- pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
+ pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
+ next_decoder_cache = None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
@@ -838,34 +911,34 @@ def forward(
if dropout_probability < self.layerdrop:
continue
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- causal_attention_mask,
+ causal_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
position_ids,
+ cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=causal_attention_mask,
+ attention_mask=causal_mask,
position_ids=position_ids,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
- past_key_value=past_key_value,
+ past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
+ cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -881,6 +954,9 @@ def forward(
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@@ -930,6 +1006,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -950,6 +1027,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
if not return_dict:
@@ -1008,6 +1086,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -1069,6 +1148,10 @@ def forward(
config.n_positions - 1]`. for padding use -1.
[What are position IDs?](../glossary#position-ids)
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
Returns:
@@ -1107,6 +1190,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
logits = self.lm_head(outputs[0]).contiguous()
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 9172b98c069e..35ad047a00dd 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -29,6 +29,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -508,7 +509,7 @@ def forward(
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
- if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = torch.sum(input_ids == self.config.image_token_index)
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
diff --git a/src/transformers/models/video_llava/configuration_video_llava.py b/src/transformers/models/video_llava/configuration_video_llava.py
index becd20040332..e761481d8259 100644
--- a/src/transformers/models/video_llava/configuration_video_llava.py
+++ b/src/transformers/models/video_llava/configuration_video_llava.py
@@ -38,8 +38,6 @@ class VideoLlavaConfig(PretrainedConfig):
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
Defaults to `LlamaConfig` if not indicated.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
video_token_index (`int`, *optional*, defaults to 32001):
@@ -88,7 +86,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32000,
video_token_index=32001,
projector_hidden_act="gelu",
@@ -99,7 +96,6 @@ def __init__(
multimodal_projector_bias=True,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.video_token_index = video_token_index
self.projector_hidden_act = projector_hidden_act
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index d8da974b9862..ba4de6537442 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -28,6 +28,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -137,6 +138,8 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
std = (
@@ -276,92 +279,6 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model.get_decoder()
- def _merge_input_ids_with_visual_features(
- self, visual_features, inputs_embeds, input_ids, attention_mask, labels, num_frames=1
- ):
- num_images, num_image_patches, embed_dim = visual_features.shape
- batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
-
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == special_vision_token
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # Compute the maximum embed dimension
- max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length
- batch_indices, non_image_indices = torch.where(input_ids != special_vision_token)
-
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = (
- torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1
- )
- nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1]
- if left_padding:
- new_token_positions += nb_image_pad[:, None] # offset for left padding
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- # expand input ids so that the second "merge" with videos does not fail
- final_embedding = torch.zeros(
- batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- final_input_ids = torch.full(
- (batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices]
- if labels is not None:
- final_labels = torch.full(
- (batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
- )
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
- else:
- final_labels = None
-
- # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
- image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device)
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- if left_padding:
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
- else:
- mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
- padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
- image_to_overwrite &= padding_mask
-
- if image_to_overwrite.sum() != visual_features.shape[:-1].numel():
- visual_type = "videos" if num_frames == 8 else "images"
- num_images //= num_frames
- raise ValueError(
- f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
- )
-
- final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids
-
def get_image_features(
self,
pixel_values_images: torch.FloatTensor,
@@ -579,14 +496,14 @@ def forward(
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -595,14 +512,14 @@ def forward(
pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer
)
- n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
- n_video_features = video_features.shape[0] * video_features.shape[1]
- if n_video_tokens != n_video_features:
+ special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel():
+ n_video_tokens = (input_ids == self.config.video_token_index).sum()
+ n_video_features = video_features.shape[0] * video_features.shape[1]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
- special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
diff --git a/src/transformers/models/vipllava/configuration_vipllava.py b/src/transformers/models/vipllava/configuration_vipllava.py
index 94d890c4b84e..ac24cce24129 100644
--- a/src/transformers/models/vipllava/configuration_vipllava.py
+++ b/src/transformers/models/vipllava/configuration_vipllava.py
@@ -37,8 +37,6 @@ class VipLlavaConfig(PretrainedConfig):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
- ignore_index (`int`, *optional*, defaults to -100):
- The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -78,7 +76,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
- ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
projector_layernorm_eps=1e-5,
@@ -86,7 +83,6 @@ def __init__(
image_seq_length=576,
**kwargs,
):
- self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.projector_layernorm_eps = projector_layernorm_eps
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 71201db2098e..ef4b3bff3958 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -28,6 +28,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -137,6 +138,8 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
def _init_weights(self, module):
# important: this ported version of VipLlava isn't meant for training from scratch - only
@@ -297,89 +300,6 @@ def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_lay
image_features = self.multi_modal_projector(image_features)
return image_features
- def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
- num_images, num_image_patches, embed_dim = image_features.shape
- batch_size, sequence_length = input_ids.shape
- left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
- # 1. Create a mask to know where special image tokens are
- special_image_token_mask = input_ids == self.config.image_token_index
- num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
- # Compute the maximum embed dimension
- max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
- batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
-
- # 2. Compute the positions where text should be written
- # Calculate new positions for text tokens in merged image-text sequence.
- # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
- # `torch.cumsum` computes how each image token shifts subsequent text token positions.
- # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
- new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
- nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
- if left_padding:
- new_token_positions += nb_image_pad[:, None] # offset for left padding
- text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
- # 3. Create the full embedding, already padded to the maximum position
- final_embedding = torch.zeros(
- batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
- )
- final_attention_mask = torch.zeros(
- batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
- )
- if labels is not None:
- final_labels = torch.full(
- (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
- )
- # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
- # set the corresponding tensors into their correct target device.
- target_device = inputs_embeds.device
- batch_indices, non_image_indices, text_to_overwrite = (
- batch_indices.to(target_device),
- non_image_indices.to(target_device),
- text_to_overwrite.to(target_device),
- )
- attention_mask = attention_mask.to(target_device)
-
- # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
- # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
- final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
- final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
- if labels is not None:
- final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
-
- # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
- image_to_overwrite = torch.full(
- (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
- )
- image_to_overwrite[batch_indices, text_to_overwrite] = False
- if left_padding:
- image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
- else:
- mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
- padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
- image_to_overwrite &= padding_mask
-
- if image_to_overwrite.sum() != image_features.shape[:-1].numel():
- raise ValueError(
- f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
- f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
- )
-
- final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
- final_attention_mask |= image_to_overwrite
- position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
- # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
- batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
- indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
- final_embedding[batch_indices, indices_to_mask] = 0
-
- if labels is None:
- final_labels = None
-
- return final_embedding, final_attention_mask, final_labels, position_ids
-
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -469,14 +389,14 @@ def forward(
pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
)
- n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
- n_image_features = image_features.shape[0] * image_features.shape[1]
- if n_image_tokens != n_image_features:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ n_image_tokens = (input_ids == self.config.image_token_index).sum()
+ n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index ce31cc844f19..3b9700dc20c9 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1783,12 +1783,12 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
model.config.use_cache = True
model.config.is_decoder = True
batch_size = input_ids.shape[0]
- max_length = 30
+ max_new_tokens = 10
# here we force to not stop at eos and go until max-length
model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
generation_kwargs = {
- "max_length": max_length,
+ "max_new_tokens": max_new_tokens,
"cache_implementation": "static",
"return_dict_in_generate": True, # Required to return `past_key_values`
}
@@ -1811,10 +1811,11 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
# we should get `max_length - 1` in shape, not `max_length - embeds_length`.
# -1 because the last generated token isn't yet in the cache.
- cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
- self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
- self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
- self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
+ max_length = max_new_tokens + inputs_embeds.shape[1] - 1
+ cache_shape = [batch_size, num_key_value_heads, max_length, head_dim]
+ self.assertIsInstance(outputs.past_key_values, StaticCache)
+ self.assertEqual(len(outputs.past_key_values.key_cache), num_hidden_layers)
+ self.assertListEqual(list(outputs.past_key_values.key_cache[0].shape), cache_shape)
@pytest.mark.generate
def test_generate_continue_from_past_key_values(self):
@@ -2022,7 +2023,7 @@ def test_generate_with_static_cache(self):
config.is_decoder = True
batch_size = main_input.shape[0]
- seq_length = main_input.shape[-1]
+ seq_length = self.model_tester.seq_length
max_new_tokens = 20
for dtype in (torch.float32, torch.float16):
@@ -2134,7 +2135,15 @@ def test_generate_compile_model_forward(self):
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
- model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
+
+ # BLIP is the only exception with custom generate which call `self.lm.generate()`
+ # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
+ # compatible with multimodality
+ if "blip" in model.__class__.__name__.lower():
+ model.language_model.generation_config.compile_config._compile_all_devices = True
+ else:
+ # force compilation (e.g. fast CI, CPU
+ model.generation_config.compile_config._compile_all_devices = True
generation_kwargs = {
"do_sample": False,
@@ -2175,7 +2184,14 @@ def test_generate_compile_model_forward(self):
)
self.assertFalse(isinstance(decoder_cache, DynamicCache))
self.assertTrue(decoder_cache.is_compileable)
- self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
+
+ # BLIP is the only exception with custom generate which call `self.lm.generate()`
+ # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
+ # compatible with multimodality
+ if "blip" in model.__class__.__name__.lower():
+ self.assertTrue(hasattr(model.language_model, "_compiled_call"))
+ else:
+ self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
self._check_similar_generate_outputs(dynamic_result, compiled_result)
@@ -2198,9 +2214,19 @@ def test_generate_compilation_all_outputs(self):
# compilation-specific setup
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
- model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
- if not has_defined_cache_implementation:
- model.generation_config.cache_implementation = "static"
+
+ # BLIP is the only exception with custom generate which call `self.lm.generate()`
+ # We should avoid such calls in all subsequent multimodal models and try to make `generate()`
+ # compatible with multimodality
+ if "blip" in model.__class__.__name__.lower():
+ model.language_model.generation_config.compile_config._compile_all_devices = True
+ if not has_defined_cache_implementation:
+ model.language_model.generation_config.cache_implementation = "static"
+ else:
+ # force compilation (e.g. fast CI, CPU)
+ model.generation_config.compile_config._compile_all_devices = True
+ if not has_defined_cache_implementation:
+ model.generation_config.cache_implementation = "static"
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
output_generate = model.generate(
@@ -2218,8 +2244,10 @@ def test_generate_compilation_all_outputs(self):
**inputs_dict,
)
- # Sanity check: compilation has happened
- self.assertTrue(hasattr(model, "_compiled_call"))
+ if "blip" in model.__class__.__name__.lower():
+ self.assertTrue(hasattr(model.language_model, "_compiled_call"))
+ else:
+ self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py
index f12ff24b17f1..8b5e62de14c7 100644
--- a/tests/models/aria/test_modeling_aria.py
+++ b/tests/models/aria/test_modeling_aria.py
@@ -286,10 +286,18 @@ def test_generate_from_inputs_embeds_0_greedy(self):
def test_generate_from_inputs_embeds_1_beam_search(self):
pass
- @unittest.skip(reason="Unsupported")
+ @unittest.skip(reason="Dynamic control flow due to MoE")
def test_generate_with_static_cache(self):
pass
+ @unittest.skip(reason="Dynamic control flow due to MoE")
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
+ @unittest.skip(reason="Dynamic control flow due to MoE")
+ def test_generate_compile_model_forward(self):
+ pass
+
@require_torch
class AriaForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py
index e26232e3eb43..a405a1f97fb3 100644
--- a/tests/models/blip_2/test_modeling_blip_2.py
+++ b/tests/models/blip_2/test_modeling_blip_2.py
@@ -816,6 +816,10 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
+ @unittest.skip("BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present")
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
# this class is based on `T5ModelTester` found in tests/models/t5/test_modeling_t5.py
class Blip2TextModelTester:
diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py
index 4563cc17dfce..491fd9f9ec4f 100644
--- a/tests/models/emu3/test_modeling_emu3.py
+++ b/tests/models/emu3/test_modeling_emu3.py
@@ -386,10 +386,6 @@ def test_disk_offload_bin(self):
def test_cpu_offload(self):
pass
- @unittest.skip("Doesn't work, tensors are not almost same") # TODO raushan fixme
- def test_custom_4d_attention_mask(self):
- pass
-
@unittest.skip("VQ-VAE module doesn't initialize weights properly")
def test_initialization(self):
pass
diff --git a/tests/models/got_ocr2/test_modeling_got_ocr2.py b/tests/models/got_ocr2/test_modeling_got_ocr2.py
index ac044de5ca96..178bec98ac62 100644
--- a/tests/models/got_ocr2/test_modeling_got_ocr2.py
+++ b/tests/models/got_ocr2/test_modeling_got_ocr2.py
@@ -256,12 +256,6 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
def test_past_key_values_format(self):
pass
- @unittest.skip(
- reason="GotOcr2 needs a dynamic control flow to pass pixel values to the forward function only in the first generation step"
- )
- def test_generate_compile_1_end_to_end(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 5d19f5b02025..32c45d6e71f7 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -838,6 +838,14 @@ def test_contrastive_generate_low_memory(self):
def test_custom_4d_attention_mask(self):
pass
+ @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
+ def test_generate_with_static_cache(self):
+ pass
+
+ @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs")
+ def test_generate_compile_model_forward(self):
+ pass
+
@unittest.skip(reason="We only test the model that takes in multiple images")
def test_model(self):
pass
diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py
index e072499ad3f1..bbf877289040 100644
--- a/tests/models/instructblip/test_modeling_instructblip.py
+++ b/tests/models/instructblip/test_modeling_instructblip.py
@@ -530,6 +530,12 @@ def test_save_load_fast_init_from_base(self):
def test_save_load_fast_init_to_base(self):
pass
+ @unittest.skip(
+ "InstructBLIP cannot generate only from input ids, and requires pixel values in all cases to be present"
+ )
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
index 0534b4f5ea73..351dea3d6fae 100644
--- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
+++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py
@@ -546,6 +546,12 @@ def test_save_load_fast_init_from_base(self):
def test_save_load_fast_init_to_base(self):
pass
+ @unittest.skip(
+ "InstructBLIPVideo cannot generate only from input ids, and requires pixel values in all cases to be present"
+ )
+ def test_generate_from_inputs_embeds_with_static_cache(self):
+ pass
+
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index 25e1a747ce9f..b47423a02ec7 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -316,14 +316,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
- @unittest.skip(reason="Compile not yet supported because in LLava models")
- def test_sdpa_can_compile_dynamic(self):
- pass
-
- @unittest.skip(reason="Compile not yet supported because in LLava models")
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py
index eaeda3cecb7b..0c75df53c1bb 100644
--- a/tests/models/llava_next/test_modeling_llava_next.py
+++ b/tests/models/llava_next/test_modeling_llava_next.py
@@ -365,22 +365,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
- @unittest.skip(reason="Feedforward chunking is not yet supported")
- def test_feed_forward_chunking(self):
- pass
-
- @unittest.skip(reason="CPU offload is not yet supported")
- def test_cpu_offload(self):
- pass
-
- @unittest.skip(reason="Compile not yet supported because in LLava models")
- def test_sdpa_can_compile_dynamic(self):
- pass
-
- @unittest.skip(reason="Compile not yet supported because in LLava models")
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@@ -391,6 +375,10 @@ def test_flash_attn_2_fp32_ln(self):
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
+ @unittest.skip("LLaVA Next has dynamic control flow in unpadding")
+ def test_generate_compile_model_forward(self):
+ pass
+
@require_torch
class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py
index 0f4642402644..6d4df92f5c22 100644
--- a/tests/models/llava_next_video/test_modeling_llava_next_video.py
+++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -382,26 +382,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
- @unittest.skip(reason="Feedforward chunking is not yet supported")
- def test_feed_forward_chunking(self):
- pass
-
- @unittest.skip(reason="CPU offload is not yet supported")
- def test_cpu_offload(self):
- pass
-
- @unittest.skip(
- reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
- )
- def test_sdpa_can_compile_dynamic(self):
- pass
-
- @unittest.skip(
- reason="Compile not yet supported because in LLava models (https://github.com/huggingface/transformers/issues/29891)"
- )
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
@@ -412,6 +392,10 @@ def test_flash_attn_2_fp32_ln(self):
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
+ @unittest.skip("LLaVA Next Video has dynamic control flow in unpadding")
+ def test_generate_compile_model_forward(self):
+ pass
+
@require_torch
class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py
index 63be10a774db..c9bb448278e7 100644
--- a/tests/models/llava_onevision/test_modeling_llava_onevision.py
+++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py
@@ -346,6 +346,10 @@ def test_flash_attn_2_fp32_ln(self):
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
+ @unittest.skip("LLaVA OneVision has dynamic control flow in unpadding")
+ def test_generate_compile_model_forward(self):
+ pass
+
@require_torch
class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase):
diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py
index 3c3256da8b24..994d88444809 100644
--- a/tests/models/mt5/test_modeling_mt5.py
+++ b/tests/models/mt5/test_modeling_mt5.py
@@ -540,7 +540,6 @@ def prepare_config_and_inputs_for_common(self):
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
- "use_cache": False,
}
return config, inputs_dict
diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py
index 3e3d2159a022..dad740cde721 100644
--- a/tests/models/opt/test_modeling_opt.py
+++ b/tests/models/opt/test_modeling_opt.py
@@ -81,7 +81,7 @@ def __init__(
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
- max_position_embeddings=20,
+ max_position_embeddings=50,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
@@ -89,7 +89,6 @@ def __init__(
num_labels=3,
word_embed_proj_dim=16,
type_sequence_label_size=2,
- attn_implementation="eager",
):
self.parent = parent
self.batch_size = batch_size
@@ -113,7 +112,6 @@ def __init__(
self.type_sequence_label_size = type_sequence_label_size
self.word_embed_proj_dim = word_embed_proj_dim
self.is_encoder_decoder = False
- self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
@@ -143,7 +141,6 @@ def get_config(self):
embed_dim=self.embed_dim,
is_encoder_decoder=False,
word_embed_proj_dim=self.word_embed_proj_dim,
- attn_implementation=self.attn_implementation,
)
def get_pipeline_config(self):
diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py
index 9886684d6088..a0439550f8f0 100644
--- a/tests/models/t5/test_modeling_t5.py
+++ b/tests/models/t5/test_modeling_t5.py
@@ -545,7 +545,6 @@ def prepare_config_and_inputs_for_common(self):
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
- "use_cache": False,
}
return config, inputs_dict
diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py
index b8d4d4167e57..528f125693f7 100644
--- a/tests/models/video_llava/test_modeling_video_llava.py
+++ b/tests/models/video_llava/test_modeling_video_llava.py
@@ -226,14 +226,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
- @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
- def test_sdpa_can_compile_dynamic(self):
- pass
-
- @unittest.skip(reason="Pass because video-LLava requires `attention_mask is not None`")
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index f6a601c8a02d..24f99d4b0b18 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -306,14 +306,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
- @unittest.skip(reason="Compile not yet supported because it is not yet supported in LLava")
- def test_sdpa_can_compile_dynamic(self):
- pass
-
- @unittest.skip(reason="Compile not yet supported because in LLava models")
- def test_sdpa_can_dispatch_on_flash(self):
- pass
-
@unittest.skip("FlashAttention only support fp16 and bf16 data type")
def test_flash_attn_2_fp32_ln(self):
pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 9dd5877c8b90..a707b25a3110 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4324,10 +4324,6 @@ def test_sdpa_can_dispatch_on_flash(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
- if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]:
- self.skipTest(
- reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input"
- )
if config.model_type in ["paligemma"]:
self.skipTest(
"PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input"
@@ -4778,6 +4774,9 @@ def test_custom_4d_attention_mask(self):
model = model_class(config).to(device=torch_device, dtype=torch.float32)
set_model_for_less_flaky_test(model)
+ if "position_ids" not in inspect.signature(model.forward).parameters:
+ continue # this model doesn't accept position ids as input
+
(
input_ids,
position_ids,