harness / diffs /37447.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 29a30f3ab70e..89a017bfdbe4 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1654,9 +1654,7 @@ class HybridCache(Cache):
```
"""
- # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
- # ALL changes from the PR that commented the line below when reactivating it.
- # is_compileable = True
+ is_compileable = True
def __init__(
self,
@@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache):
```
"""
- # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
- # ALL changes from the PR that commented the line below when reactivating it.
is_compileable = True
def __init__(
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index a7189a1a212b..1af14d021c7a 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -42,6 +42,7 @@
logging,
replace_return_docstrings,
)
+from ...utils.deprecation import deprecate_kwarg
from .configuration_cohere2 import Cohere2Config
@@ -300,6 +301,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -309,7 +311,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -330,7 +331,6 @@ def forward(
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
- last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -349,11 +349,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -539,6 +544,7 @@ def set_input_embeddings(self, value):
@can_return_tuple
@add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -550,7 +556,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,16 +595,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -627,7 +622,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -638,7 +632,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -928,10 +921,6 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py
index 3d1bdaeca944..85a8d04a5091 100644
--- a/src/transformers/models/cohere2/modular_cohere2.py
+++ b/src/transformers/models/cohere2/modular_cohere2.py
@@ -23,15 +23,12 @@
from ...cache_utils import Cache, HybridCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
- BaseModelOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
-from ...utils import (
- logging,
-)
+from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging
+from ...utils.deprecation import deprecate_kwarg
from ..cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
@@ -45,6 +42,9 @@
from ..gemma2.modeling_gemma2 import Gemma2Model
+COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular
+
+
logger = logging.get_logger(__name__)
@@ -351,6 +351,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -360,7 +361,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -381,7 +381,6 @@ def forward(
(see `past_key_values`).
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
- last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -400,11 +399,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -452,6 +456,9 @@ def __init__(self, config: Cohere2Config):
self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.rotary_emb = Cohere2RotaryEmbedding(config=config)
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -463,7 +470,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -503,16 +509,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -540,7 +536,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -551,7 +546,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -625,10 +619,6 @@ def prepare_inputs_for_generation(
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
-
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index f0d340048fb7..353b171042f3 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -47,6 +47,7 @@
logging,
replace_return_docstrings,
)
+from ...utils.deprecation import deprecate_kwarg
from .configuration_gemma2 import Gemma2Config
@@ -285,6 +286,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -295,7 +297,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -314,11 +315,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -542,6 +548,7 @@ def set_input_embeddings(self, value):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -553,7 +560,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -594,16 +600,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -639,7 +635,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -651,7 +646,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -922,9 +916,6 @@ def prepare_inputs_for_generation(
**kwargs,
)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index e06a701fc527..b219384f34ab 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -24,13 +24,11 @@
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
-)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
-from ...utils import is_torch_flex_attn_available, logging
+from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging
+from ...utils.deprecation import deprecate_kwarg
from ..gemma.modeling_gemma import (
GemmaAttention,
GemmaForCausalLM,
@@ -45,6 +43,7 @@
_CHECKPOINT_FOR_DOC = "google/gemma2-7b"
+GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular
if is_torch_flex_attn_available():
@@ -334,6 +333,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -344,7 +344,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -363,11 +362,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -409,6 +413,9 @@ def __init__(self, config: Gemma2Config):
[Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -420,7 +427,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -461,16 +467,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
@@ -506,7 +502,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -518,7 +513,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -702,9 +696,6 @@ def prepare_inputs_for_generation(
**kwargs,
)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 50ca08a3f10a..170e3d952f31 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -45,6 +45,7 @@
logging,
replace_return_docstrings,
)
+from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
@@ -377,6 +378,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -388,7 +390,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -407,11 +408,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -626,6 +632,7 @@ def set_input_embeddings(self, value):
@can_return_tuple
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -637,7 +644,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -678,16 +684,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
@@ -723,7 +719,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -736,7 +731,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
@@ -1009,9 +1003,6 @@ def prepare_inputs_for_generation(
**kwargs,
)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
if logits_to_keep is None:
_ = model_inputs.pop("logits_to_keep", None)
diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py
index f2e716f21628..7c95f63b0e04 100644
--- a/src/transformers/models/gemma3/modular_gemma3.py
+++ b/src/transformers/models/gemma3/modular_gemma3.py
@@ -26,11 +26,7 @@
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
-from ...modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- ModelOutput,
-)
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
@@ -41,6 +37,7 @@
logging,
replace_return_docstrings,
)
+from ...utils.deprecation import deprecate_kwarg
from ..gemma2.configuration_gemma2 import Gemma2Config
from ..gemma2.modeling_gemma2 import (
Gemma2Attention,
@@ -460,6 +457,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
self.is_sliding = self.self_attn.is_sliding
self.sliding_window = config.sliding_window
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
hidden_states: torch.Tensor,
@@ -471,7 +469,6 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: int = 0,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
@@ -490,11 +487,16 @@ def forward(
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
# In case we are beyond the sliding window, we need to correctly offset the mask slicing
- # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
- offset = last_cache_position - effective_seq_len
+ offset = cache_position[-1] - effective_seq_len + 1
# Should only be used when beyond the sliding window (i.e. offset > 0)
offset = max(0, offset)
- attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
+ # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`,
+ # but without data-dependent slicing (i.e. torch.compile friendly)
+ mask_indexes = torch.arange(
+ min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
+ )
+ mask_indexes += offset
+ attention_mask = attention_mask[:, :, :, mask_indexes]
residual = hidden_states
@@ -581,6 +583,9 @@ def __init__(self, config: Gemma3TextConfig):
config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
+ @deprecate_kwarg("last_cache_position", version="4.53.0")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -592,7 +597,6 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- last_cache_position: Optional[int] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -633,16 +637,6 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
- # (retrieving the same value from `cache_position` later on would crash dynamo)
- if last_cache_position is None:
- last_cache_position = 0
- if attention_mask is not None:
- # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
- # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
- last_cache_position = (
- attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
- )
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
@@ -678,7 +672,6 @@ def forward(
output_attentions,
use_cache,
cache_position,
- last_cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -691,7 +684,6 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
- last_cache_position=last_cache_position,
**flash_attn_kwargs,
)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 0672589769ad..fa8bd274cce2 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2075,9 +2075,6 @@ def test_generate_compile_model_forward(self):
Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results.
⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️
"""
- # Monkey-patching the HybridCache at test-time to continue testing compilation support
- HybridCache.is_compileable = True
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
@@ -2174,9 +2171,6 @@ def test_generate_compilation_all_outputs(self):
Tests that all optional outputs are behaving as expected when compilation is triggered.
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
"""
- # Monkey-patching the HybridCache at test-time to continue testing compilation support
- HybridCache.is_compileable = True
-
for model_class in self.all_generative_model_classes:
if not model_class._supports_static_cache:
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 3d396d8f7f38..c6277af594ca 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -153,6 +153,10 @@ def test_sdpa_equivalence(self):
def test_multi_gpu_data_parallel_forward(self):
pass
+ @unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
+ def test_eager_matches_fa2_generate(self):
+ pass
+
@slow
@require_torch_gpu
diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py
index 935d8b884a11..be83749cf8bc 100644
--- a/tests/models/gemma3/test_modeling_gemma3.py
+++ b/tests/models/gemma3/test_modeling_gemma3.py
@@ -329,6 +329,10 @@ def test_generate_with_static_cache(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass
+ @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.")
+ def test_eager_matches_fa2_generate(self):
+ pass
+
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)