| |
| |
| |
| |
| @@ -1654,9 +1654,7 @@ class HybridCache(Cache): |
| ``` |
| """ |
| |
| - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert |
| - # ALL changes from the PR that commented the line below when reactivating it. |
| - # is_compileable = True |
| + is_compileable = True |
| |
| def __init__( |
| self, |
| @@ -1858,8 +1856,6 @@ class HybridChunkedCache(Cache): |
| ``` |
| """ |
| |
| - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert |
| - # ALL changes from the PR that commented the line below when reactivating it. |
| is_compileable = True |
| |
| def __init__( |
| |
| |
| |
| |
| @@ -42,6 +42,7 @@ |
| logging, |
| replace_return_docstrings, |
| ) |
| +from ...utils.deprecation import deprecate_kwarg |
| from .configuration_cohere2 import Cohere2Config |
| |
| |
| @@ -300,6 +301,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int): |
| self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -309,7 +311,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| """ |
| @@ -330,7 +331,6 @@ def forward( |
| (see `past_key_values`). |
| cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| Indices depicting the position of the input sequence tokens in the sequence |
| - last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing |
| """ |
| |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -349,11 +349,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -539,6 +544,7 @@ def set_input_embeddings(self, value): |
| |
| @can_return_tuple |
| @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -550,7 +556,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -590,16 +595,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| ) |
| @@ -627,7 +622,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -638,7 +632,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| @@ -928,10 +921,6 @@ def prepare_inputs_for_generation( |
| # The clone here is for the same reason as for `position_ids`. |
| model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 |
| - |
| if ( |
| isinstance(past_key_values, HybridCache) |
| and attention_mask.ndim == 2 |
| |
| |
| |
| |
| @@ -23,15 +23,12 @@ |
| from ...cache_utils import Cache, HybridCache |
| from ...configuration_utils import PretrainedConfig |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| -from ...modeling_outputs import ( |
| - BaseModelOutputWithPast, |
| -) |
| +from ...modeling_outputs import BaseModelOutputWithPast |
| from ...modeling_rope_utils import rope_config_validation |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from ...processing_utils import Unpack |
| -from ...utils import ( |
| - logging, |
| -) |
| +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, logging |
| +from ...utils.deprecation import deprecate_kwarg |
| from ..cohere.modeling_cohere import ( |
| CohereAttention, |
| CohereDecoderLayer, |
| @@ -45,6 +42,9 @@ |
| from ..gemma2.modeling_gemma2 import Gemma2Model |
| |
| |
| +COHERE2_INPUTS_DOCSTRING = None # Will be picked up by modular |
| + |
| + |
| logger = logging.get_logger(__name__) |
| |
| |
| @@ -351,6 +351,7 @@ def __init__(self, config: Cohere2Config, layer_idx: int): |
| self.is_sliding = (layer_idx + 1) % self.config.sliding_window_pattern != 0 |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -360,7 +361,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs: Unpack[FlashAttentionKwargs], |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| """ |
| @@ -381,7 +381,6 @@ def forward( |
| (see `past_key_values`). |
| cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| Indices depicting the position of the input sequence tokens in the sequence |
| - last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing |
| """ |
| |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -400,11 +399,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -452,6 +456,9 @@ def __init__(self, config: Cohere2Config): |
| self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) |
| self.rotary_emb = Cohere2RotaryEmbedding(config=config) |
| |
| + @can_return_tuple |
| + @add_start_docstrings_to_model_forward(COHERE2_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -463,7 +470,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -503,16 +509,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| ) |
| @@ -540,7 +536,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -551,7 +546,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| @@ -625,10 +619,6 @@ def prepare_inputs_for_generation( |
| # The clone here is for the same reason as for `position_ids`. |
| model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 |
| - |
| if ( |
| isinstance(past_key_values, HybridCache) |
| and attention_mask.ndim == 2 |
| |
| |
| |
| |
| @@ -47,6 +47,7 @@ |
| logging, |
| replace_return_docstrings, |
| ) |
| +from ...utils.deprecation import deprecate_kwarg |
| from .configuration_gemma2 import Gemma2Config |
| |
| |
| @@ -285,6 +286,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): |
| self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -295,7 +297,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs, |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -314,11 +315,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -542,6 +548,7 @@ def set_input_embeddings(self, value): |
| |
| @can_return_tuple |
| @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -553,7 +560,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -594,16 +600,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| ) |
| @@ -639,7 +635,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -651,7 +646,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| @@ -922,9 +916,6 @@ def prepare_inputs_for_generation( |
| **kwargs, |
| ) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 |
| if logits_to_keep is None: |
| _ = model_inputs.pop("logits_to_keep", None) |
| |
| |
| |
| |
| |
| @@ -24,13 +24,11 @@ |
| from ...cache_utils import Cache, HybridCache, StaticCache |
| from ...configuration_utils import PretrainedConfig |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| -from ...modeling_outputs import ( |
| - BaseModelOutputWithPast, |
| - CausalLMOutputWithPast, |
| -) |
| +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from ...processing_utils import Unpack |
| -from ...utils import is_torch_flex_attn_available, logging |
| +from ...utils import add_start_docstrings_to_model_forward, can_return_tuple, is_torch_flex_attn_available, logging |
| +from ...utils.deprecation import deprecate_kwarg |
| from ..gemma.modeling_gemma import ( |
| GemmaAttention, |
| GemmaForCausalLM, |
| @@ -45,6 +43,7 @@ |
| |
| |
| _CHECKPOINT_FOR_DOC = "google/gemma2-7b" |
| +GEMMA2_INPUTS_DOCSTRING = None # Will be picked up by modular |
| |
| |
| if is_torch_flex_attn_available(): |
| @@ -334,6 +333,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): |
| self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -344,7 +344,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs, |
| ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -363,11 +362,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -409,6 +413,9 @@ def __init__(self, config: Gemma2Config): |
| [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
| ) |
| |
| + @can_return_tuple |
| + @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -420,7 +427,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -461,16 +467,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
| ) |
| @@ -506,7 +502,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -518,7 +513,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| @@ -702,9 +696,6 @@ def prepare_inputs_for_generation( |
| **kwargs, |
| ) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 |
| if logits_to_keep is None: |
| _ = model_inputs.pop("logits_to_keep", None) |
| |
| |
| |
| |
| |
| @@ -45,6 +45,7 @@ |
| logging, |
| replace_return_docstrings, |
| ) |
| +from ...utils.deprecation import deprecate_kwarg |
| from ..auto import AutoModel, AutoModelForCausalLM |
| from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig |
| |
| @@ -377,6 +378,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): |
| self.is_sliding = self.self_attn.is_sliding |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -388,7 +390,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs, |
| ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -407,11 +408,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -626,6 +632,7 @@ def set_input_embeddings(self, value): |
| |
| @can_return_tuple |
| @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -637,7 +644,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -678,16 +684,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, |
| inputs_embeds, |
| @@ -723,7 +719,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -736,7 +731,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| @@ -1009,9 +1003,6 @@ def prepare_inputs_for_generation( |
| **kwargs, |
| ) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0 |
| if logits_to_keep is None: |
| _ = model_inputs.pop("logits_to_keep", None) |
| |
| |
| |
| |
| |
| @@ -26,11 +26,7 @@ |
| from ...cache_utils import Cache, HybridCache, StaticCache |
| from ...configuration_utils import PretrainedConfig |
| from ...modeling_flash_attention_utils import FlashAttentionKwargs |
| -from ...modeling_outputs import ( |
| - BaseModelOutputWithPast, |
| - CausalLMOutputWithPast, |
| - ModelOutput, |
| -) |
| +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput |
| from ...modeling_rope_utils import rope_config_validation |
| from ...modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from ...processing_utils import Unpack |
| @@ -41,6 +37,7 @@ |
| logging, |
| replace_return_docstrings, |
| ) |
| +from ...utils.deprecation import deprecate_kwarg |
| from ..gemma2.configuration_gemma2 import Gemma2Config |
| from ..gemma2.modeling_gemma2 import ( |
| Gemma2Attention, |
| @@ -460,6 +457,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): |
| self.is_sliding = self.self_attn.is_sliding |
| self.sliding_window = config.sliding_window |
| |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| @@ -471,7 +469,6 @@ def forward( |
| output_attentions: Optional[bool] = False, |
| use_cache: Optional[bool] = False, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: int = 0, |
| **kwargs, |
| ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: |
| if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding |
| @@ -490,11 +487,16 @@ def forward( |
| ) |
| attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) |
| # In case we are beyond the sliding window, we need to correctly offset the mask slicing |
| - # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo |
| - offset = last_cache_position - effective_seq_len |
| + offset = cache_position[-1] - effective_seq_len + 1 |
| # Should only be used when beyond the sliding window (i.e. offset > 0) |
| offset = max(0, offset) |
| - attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len] |
| + # equivalent to: `attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]`, |
| + # but without data-dependent slicing (i.e. torch.compile friendly) |
| + mask_indexes = torch.arange( |
| + min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device |
| + ) |
| + mask_indexes += offset |
| + attention_mask = attention_mask[:, :, :, mask_indexes] |
| |
| residual = hidden_states |
| |
| @@ -581,6 +583,9 @@ def __init__(self, config: Gemma3TextConfig): |
| config.rope_scaling = {"rope_type": "default"} |
| self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) |
| |
| + @can_return_tuple |
| + @add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) |
| + @deprecate_kwarg("last_cache_position", version="4.53.0") |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| @@ -592,7 +597,6 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| - last_cache_position: Optional[int] = None, |
| **flash_attn_kwargs: Unpack[FlashAttentionKwargs], |
| ) -> BaseModelOutputWithPast: |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -633,16 +637,6 @@ def forward( |
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
| |
| - # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing |
| - # (retrieving the same value from `cache_position` later on would crash dynamo) |
| - if last_cache_position is None: |
| - last_cache_position = 0 |
| - if attention_mask is not None: |
| - # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position |
| - # It will break dynamo tracing but there are no way around it (and it should never happen in practice) |
| - last_cache_position = ( |
| - attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item() |
| - ) |
| causal_mask = self._update_causal_mask( |
| attention_mask, |
| inputs_embeds, |
| @@ -678,7 +672,6 @@ def forward( |
| output_attentions, |
| use_cache, |
| cache_position, |
| - last_cache_position, |
| ) |
| else: |
| layer_outputs = decoder_layer( |
| @@ -691,7 +684,6 @@ def forward( |
| output_attentions=output_attentions, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| - last_cache_position=last_cache_position, |
| **flash_attn_kwargs, |
| ) |
| |
| |
| |
| |
| |
| @@ -2075,9 +2075,6 @@ def test_generate_compile_model_forward(self): |
| Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. |
| ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ |
| """ |
| - # Monkey-patching the HybridCache at test-time to continue testing compilation support |
| - HybridCache.is_compileable = True |
| - |
| for model_class in self.all_generative_model_classes: |
| if not model_class._supports_static_cache: |
| self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
| @@ -2174,9 +2171,6 @@ def test_generate_compilation_all_outputs(self): |
| Tests that all optional outputs are behaving as expected when compilation is triggered. |
| In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. |
| """ |
| - # Monkey-patching the HybridCache at test-time to continue testing compilation support |
| - HybridCache.is_compileable = True |
| - |
| for model_class in self.all_generative_model_classes: |
| if not model_class._supports_static_cache: |
| self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
| |
| |
| |
| |
| @@ -153,6 +153,10 @@ def test_sdpa_equivalence(self): |
| def test_multi_gpu_data_parallel_forward(self): |
| pass |
| |
| + @unittest.skip("Gemma2 has HybridCache which auto-compiles. Compile and FA2 don't work together.") |
| + def test_eager_matches_fa2_generate(self): |
| + pass |
| + |
| |
| @slow |
| @require_torch_gpu |
| |
| |
| |
| |
| @@ -329,6 +329,10 @@ def test_generate_with_static_cache(self): |
| def test_generate_from_inputs_embeds_with_static_cache(self): |
| pass |
| |
| + @unittest.skip("Gemma3 has HybridCache which auto-compiles. Compile and FA2 don't work together.") |
| + def test_eager_matches_fa2_generate(self): |
| + pass |
| + |
| @unittest.skip( |
| reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation" |
| ) |
|
|