harness / diffs /39505.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py
index 19b059699e28..e56eeec7d75b 100644
--- a/examples/modular-transformers/modeling_my_new_model2.py
+++ b/examples/modular-transformers/modeling_my_new_model2.py
@@ -294,7 +294,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": MyNewModel2DecoderLayer,
diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py
index 9111883cfef6..2a3df8e9c1d0 100644
--- a/examples/modular-transformers/modeling_new_task_model.py
+++ b/examples/modular-transformers/modeling_new_task_model.py
@@ -94,7 +94,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_cache_class = True
_supports_quantized_cache = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py
index fc90cce75a5f..ee90750cac25 100644
--- a/examples/modular-transformers/modeling_super.py
+++ b/examples/modular-transformers/modeling_super.py
@@ -293,7 +293,7 @@ class SuperPreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_quantized_cache = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": SuperDecoderLayer,
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index e360acdac341..d20890ba9583 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2042,7 +2042,7 @@ def _prepare_cache_for_generation(
)
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
- if generation_config.cache_implementation == "static" and not self._supports_static_cache:
+ if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
@@ -2198,7 +2198,8 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge
using_compilable_cache = (
isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
)
- can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache
+ # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
+ can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph
# Exception 1: Some quantization methods do not support compilation
if getattr(self, "hf_quantizer", None) is not None:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 56e4145250a0..3cb186a3a317 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2062,8 +2062,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
# Flex Attention support
_supports_flex_attn = False
- # Has support `torch.compile(fullgraph=True)`
- _supports_static_cache = False
+ _can_compile_fullgraph = False
# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 43a02ebd8cb8..388ad6a392ce 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -317,7 +317,7 @@ class ArceePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ArceeDecoderLayer,
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 6c5c972b1f9a..68f1df91e1fb 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -664,7 +664,7 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": AriaTextDecoderLayer,
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index f6303b4d382e..93a77e2b8631 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -1312,7 +1312,7 @@ def _init_weights(self, module):
class AriaPreTrainedModel(LlamaPreTrainedModel):
config: AriaConfig
base_model_prefix = ""
- _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py
index cb4847068778..ba2095a88846 100644
--- a/src/transformers/models/aya_vision/modeling_aya_vision.py
+++ b/src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -96,7 +96,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py
index 58c118d73fad..53962d07d8fc 100644
--- a/src/transformers/models/aya_vision/modular_aya_vision.py
+++ b/src/transformers/models/aya_vision/modular_aya_vision.py
@@ -90,7 +90,7 @@ def pixel_shuffle(self, image_features): # B, S, D
class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
- _supports_static_cache = False
+ _can_compile_fullgraph = False
def _init_weights(self, module):
std = (
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 055b696405dc..c4c007c42aa5 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -493,7 +493,7 @@ class BartPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index d567808f95af..64916fcb47e6 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1565,7 +1565,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_param_buffer_assignment = False
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index a873cd6b6967..61576ac5d1bb 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -347,7 +347,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py
index 3f63caddeaee..e3e17f4e944e 100644
--- a/src/transformers/models/biogpt/modular_biogpt.py
+++ b/src/transformers/models/biogpt/modular_biogpt.py
@@ -172,7 +172,7 @@ class BioGptPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index c373d659c743..cd1ac751d793 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -312,7 +312,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": BitNetDecoderLayer,
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 65f0378ef531..aee192473aef 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -458,7 +458,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 9030bd1e5ce6..c90d168c5889 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -451,7 +451,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 4c7a52e6fbbc..b19ae2f8dc44 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -1831,7 +1831,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config: Blip2Config
main_input_name = "pixel_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_keep_in_fp32_modules = ["query_tokens", "qformer"]
_supports_flash_attn = False # because self.qformer does not support FA2
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index f999872bef55..cc8cd4eae90b 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -434,7 +434,7 @@ class BloomPreTrainedModel(PreTrainedModel):
_no_split_modules = ["BloomBlock"]
_skip_keys_device_placement = "past_key_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 5d8e6fc21073..b7cfe29119fb 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -815,7 +815,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_param_buffer_assignment = False
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index b1378f55175b..3dbb6f5ecce3 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -287,7 +287,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
_no_split_modules = ["CodeGenBlock"]
_skip_keys_device_placement = "past_key_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 1fb91bccaa61..d299a0f192dc 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -345,7 +345,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": CohereDecoderLayer,
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index f3dc518f9246..b3c7c1b10fb8 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -322,7 +322,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Cohere2DecoderLayer,
diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py
index 31e9ff368978..16d2292cdaeb 100644
--- a/src/transformers/models/csm/modeling_csm.py
+++ b/src/transformers/models/csm/modeling_csm.py
@@ -371,7 +371,7 @@ class CsmPreTrainedModel(PreTrainedModel):
# does not because of Mimi codec model
# _supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": CsmDecoderLayer,
diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py
index 470161275244..688795468d97 100644
--- a/src/transformers/models/csm/modular_csm.py
+++ b/src/transformers/models/csm/modular_csm.py
@@ -134,7 +134,7 @@ class CsmPreTrainedModel(PreTrainedModel):
# does not because of Mimi codec model
# _supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": CsmDecoderLayer,
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 86b4944f08f1..ee5ec65f86bc 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -810,7 +810,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module: nn.Module):
std = self.config.initializer_range
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 1b33296d7dde..6a492e937a51 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -453,7 +453,7 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
is_parallelizable = True
supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
index 595953fd6c18..038a13be9faa 100644
--- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
+++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -459,7 +459,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": DeepseekV2DecoderLayer,
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index 05171a8359d7..ba50224f1edd 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -498,7 +498,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": DeepseekV3DecoderLayer,
diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py
index 2bf05cf683ce..4ff33698583b 100644
--- a/src/transformers/models/dia/modeling_dia.py
+++ b/src/transformers/models/dia/modeling_dia.py
@@ -67,7 +67,7 @@ class DiaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
main_input_name = "input_ids"
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py
index 8c84d936c543..da50046b4ba0 100644
--- a/src/transformers/models/dia/modular_dia.py
+++ b/src/transformers/models/dia/modular_dia.py
@@ -62,7 +62,7 @@ class DiaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
main_input_name = "input_ids"
_no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 92badf62f267..64f22f410f13 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -534,7 +534,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = False
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = False
_can_record_outputs = {
"hidden_states": DiffLlamaDecoderLayer,
diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py
index 21b2794c03bd..d8823dc7f219 100644
--- a/src/transformers/models/doge/modeling_doge.py
+++ b/src/transformers/models/doge/modeling_doge.py
@@ -494,7 +494,7 @@ class DogePreTrainedModel(PreTrainedModel):
_supports_flash_attn = False
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(DogeCDMoE, index=1),
diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py
index a3d1b4f9bf32..5f7f582f0ce0 100644
--- a/src/transformers/models/doge/modular_doge.py
+++ b/src/transformers/models/doge/modular_doge.py
@@ -564,7 +564,7 @@ def forward(
class DogePreTrainedModel(LlamaPreTrainedModel):
_supports_flash_attn = False
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_can_record_outputs = {
"router_logits": OutputRecorder(DogeCDMoE, index=1),
"hidden_states": DogeDecoderLayer,
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index 06df98b83568..8815790b5efc 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -418,7 +418,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Dots1DecoderLayer,
diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py
index fd7cbf39e183..1c6cf72f13c9 100644
--- a/src/transformers/models/emu3/modeling_emu3.py
+++ b/src/transformers/models/emu3/modeling_emu3.py
@@ -1098,7 +1098,7 @@ class Emu3PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_param_buffer_assignment = False
_supports_flex_attn = True
_supports_attention_backend = True
@@ -1320,7 +1320,6 @@ def forward(
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
@@ -1463,7 +1462,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py
index 580bf670e3ca..6e18bdbc22f5 100644
--- a/src/transformers/models/emu3/modular_emu3.py
+++ b/src/transformers/models/emu3/modular_emu3.py
@@ -902,7 +902,6 @@ def forward(**super_kwargs):
class Emu3Model(Emu3PreTrainedModel):
_checkpoint_conversion_mapping = {"text_model.model": "text_model"}
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
@@ -1045,7 +1044,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
"^vqmodel": "model.vqmodel",
"^text_model.lm_head": "lm_head",
}
- _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
def __init__(self, config):
super().__init__(config)
diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
index 0d98583e1b44..507d6109ff09 100644
--- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -311,7 +311,7 @@ class Ernie4_5PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Ernie4_5DecoderLayer,
diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
index 630162b7dc5c..0a93578e1c4f 100644
--- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
+++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py
@@ -473,7 +473,7 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1),
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 4033e2c14d58..5cd2bd505865 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -643,7 +643,7 @@ class FalconPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 287c9b3013cc..ea25f6e60e36 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -314,7 +314,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": GemmaDecoderLayer,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 6db1b1f7bbfb..819f15f958c0 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -344,7 +344,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Gemma2DecoderLayer,
diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py
index 394e38002171..90ec650c2a5b 100644
--- a/src/transformers/models/gemma3/modeling_gemma3.py
+++ b/src/transformers/models/gemma3/modeling_gemma3.py
@@ -434,7 +434,7 @@ class Gemma3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Gemma3DecoderLayer,
diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py
index 1411cccef9a9..eb94159d30f7 100644
--- a/src/transformers/models/gemma3n/modeling_gemma3n.py
+++ b/src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -1490,7 +1490,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Gemma3nTextDecoderLayer,
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 72733a80f769..00ee47bdd2e2 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -331,7 +331,7 @@ class GlmPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": GlmDecoderLayer,
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index e4dd64102d56..1c1980f13ad8 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -335,7 +335,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Glm4DecoderLayer,
diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
index 31ad8ede952f..b4f16c0f2f67 100644
--- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -403,7 +403,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Glm4MoeDecoderLayer,
diff --git a/src/transformers/models/glm4_moe/modular_glm4_moe.py b/src/transformers/models/glm4_moe/modular_glm4_moe.py
index 509fc39d391d..1f6628d938da 100644
--- a/src/transformers/models/glm4_moe/modular_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modular_glm4_moe.py
@@ -310,7 +310,7 @@ class Glm4MoeDecoderLayer(DeepseekV3DecoderLayer):
class Glm4MoePreTrainedModel(DeepseekV3PreTrainedModel):
- _supports_static_cache = False
+ _can_compile_fullgraph = False
class Glm4MoeModel(DeepseekV3Model):
diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py
index 41e37b3e1a7a..82897df659e2 100644
--- a/src/transformers/models/glm4v/modeling_glm4v.py
+++ b/src/transformers/models/glm4v/modeling_glm4v.py
@@ -407,7 +407,7 @@ class Glm4vPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
index 99959cd74bd5..3c2ee44dd569 100644
--- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py
+++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -283,7 +283,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = False
_supports_sdpa = False
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = False
_supports_attention_backend = True
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index c853d80e4ae5..80442af9110d 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -563,7 +563,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_attention_backend = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 7d655bd0e6a5..89e6f7182a75 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -477,7 +477,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
- _supports_static_cache = False # TODO: needs a HybridCache
+ _can_compile_fullgraph = False # TODO: needs a HybridCache
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 15cc664d74b8..52fb7b595ac1 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -364,7 +364,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": GPTNeoXDecoderLayer,
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 9e1859e794b8..e80f7880239f 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -48,7 +48,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 43822682df35..8c388ac77b44 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -472,7 +472,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_param_buffer_assignment = False
def __init__(self, *inputs, **kwargs):
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 8bebef03c225..610bf5ef521f 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -309,7 +309,7 @@ class GranitePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": GraniteDecoderLayer,
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index bf72cc85da70..bf9c24a28a24 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -592,7 +592,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
index 8b3f3d1dccd9..93fffb76116d 100644
--- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
+++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
@@ -1208,7 +1208,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_is_stateful = True
def _init_weights(self, module):
diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
index 83f78ae32774..3493f42c87f7 100644
--- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
+++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py
@@ -510,7 +510,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index f15fbd48ddad..39ed4c8bc903 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -316,7 +316,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": HeliumDecoderLayer,
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 741886c1425b..ac8b7776c564 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -880,7 +880,7 @@ class IdeficsPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn = True
- _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs
+ _can_compile_fullgraph = False # IDEFICS cannot compile due to dynamic control flow when checking inputs
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index b88c003660b4..bcafeeec1e73 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -340,7 +340,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_no_split_modules = [
"InstructBlipQFormerEmbeddings",
@@ -1354,7 +1354,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati
config: InstructBlipConfig
main_input_name = "pixel_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipConfig):
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index cec919825324..8e098183e274 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -827,7 +827,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_no_split_modules = [
"InstructBlipVideoQFormerEmbeddings",
@@ -1360,7 +1360,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel
config: InstructBlipVideoConfig
main_input_name = "pixel_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
def __init__(self, config: InstructBlipVideoConfig):
diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py
index 46ef56dc4600..6c0b096bc43e 100644
--- a/src/transformers/models/internvl/modeling_internvl.py
+++ b/src/transformers/models/internvl/modeling_internvl.py
@@ -524,7 +524,7 @@ class InternVLPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py
index 2a2257dfecf8..f211db293835 100644
--- a/src/transformers/models/janus/modeling_janus.py
+++ b/src/transformers/models/janus/modeling_janus.py
@@ -63,7 +63,7 @@ class JanusPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_param_buffer_assignment = False
def _init_weights(self, module):
@@ -1123,7 +1123,7 @@ def forward(
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, config: JanusConfig):
super().__init__(config)
diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py
index 81e63901831e..6cd7c9d47816 100644
--- a/src/transformers/models/janus/modular_janus.py
+++ b/src/transformers/models/janus/modular_janus.py
@@ -390,7 +390,7 @@ class JanusPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_param_buffer_assignment = False
def _init_weights(self, module):
@@ -982,7 +982,7 @@ def forward(
class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def __init__(self, config: JanusConfig):
super().__init__(config)
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index 0d383769d1c4..7dd665d91289 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -542,7 +542,7 @@ class Lfm2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Lfm2DecoderLayer,
diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py
index 338e6ec5242d..c53a63b45e94 100644
--- a/src/transformers/models/lfm2/modular_lfm2.py
+++ b/src/transformers/models/lfm2/modular_lfm2.py
@@ -401,7 +401,7 @@ def forward(
class Lfm2PreTrainedModel(LlamaPreTrainedModel):
- _supports_static_cache = False
+ _can_compile_fullgraph = False
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 4bab75a87ce4..043857744e4e 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -315,7 +315,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": LlamaDecoderLayer,
diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py
index 85aeb70ce38b..53d9367b7c18 100644
--- a/src/transformers/models/llama4/modeling_llama4.py
+++ b/src/transformers/models/llama4/modeling_llama4.py
@@ -437,7 +437,7 @@ class Llama4PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 032751a4e116..b9576294b21e 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -121,7 +121,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 7cbad1b98093..94f03925b8b0 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -232,7 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 7721d760eacc..dce37a4dd985 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -173,7 +173,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index ea5ca1e5ea63..41c39d26f38b 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -286,7 +286,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index d4e29c619b11..261277f492b1 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -1250,7 +1250,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["LongT5Block"]
- _supports_static_cache = False # TODO: @raushan more involved due to local/global attn
+ _can_compile_fullgraph = False # TODO: @raushan more involved due to local/global attn
@property
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 6790872107a1..01a26c5ae61c 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -525,7 +525,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
- _supports_static_cache = False
+ _can_compile_fullgraph = False
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 5f988b3a82ee..c328d634029c 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -467,7 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
std = self.config.init_std
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 0a6880415f9e..93f7b2bef8e4 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -492,7 +492,7 @@ class MBartPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 0be377794d0a..260ea6f7ce2e 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -1376,7 +1376,7 @@ class MimiPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index 6923fdc91abb..a90d1e97e958 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -588,8 +588,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- # Note: only supports MiniMaxCache
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py
index 423ae27717c4..36c5f0a29184 100644
--- a/src/transformers/models/minimax/modular_minimax.py
+++ b/src/transformers/models/minimax/modular_minimax.py
@@ -470,8 +470,7 @@ def forward(
class MiniMaxPreTrainedModel(MixtralPreTrainedModel):
- # Note: only supports MiniMaxCache
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_can_record_outputs = {
"router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1),
"hidden_states": MiniMaxDecoderLayer,
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 1189f1b34d38..aea8532b53b6 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -260,7 +260,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": MistralDecoderLayer,
diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py
index d1a9c83f9dad..3fbc2db078d6 100644
--- a/src/transformers/models/mistral3/modeling_mistral3.py
+++ b/src/transformers/models/mistral3/modeling_mistral3.py
@@ -186,7 +186,7 @@ class Mistral3PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index ec185002691d..72e3df7f1914 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -388,7 +388,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py
index de02a2a833bc..c4a7b5b2df6c 100644
--- a/src/transformers/models/mixtral/modular_mixtral.py
+++ b/src/transformers/models/mixtral/modular_mixtral.py
@@ -277,7 +277,7 @@ class MixtralRotaryEmbedding(MistralRotaryEmbedding):
class MixtralPreTrainedModel(MistralPreTrainedModel):
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_can_record_outputs = {
"router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1),
"hidden_states": MixtralDecoderLayer,
diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py
index 7c126f42f1e5..e15e7f9e2683 100644
--- a/src/transformers/models/mllama/modeling_mllama.py
+++ b/src/transformers/models/mllama/modeling_mllama.py
@@ -850,7 +850,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
"MllamaSelfAttentionDecoderLayer",
]
- _supports_static_cache = False # static cache cannot have different shapes for each layer
+ _can_compile_fullgraph = False # static cache cannot have different shapes for each layer
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
@@ -1449,7 +1449,7 @@ def forward(
)
class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin):
config: MllamaTextConfig
- _supports_static_cache = True # only the LLM without cross attn can do compile
+ _can_compile_fullgraph = True # only the LLM without cross attn can do compile
base_model_prefix = "language_model"
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
index c3d90771e11f..0b1d572a1f79 100644
--- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py
@@ -224,7 +224,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = False
_supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ModernBertDecoderLayer,
diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
index d215ccbf0bc7..3b6b936e15e7 100644
--- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
+++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py
@@ -401,7 +401,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = False
_supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ModernBertDecoderLayer,
diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py
index 9b229e4074c0..9bd948c99269 100644
--- a/src/transformers/models/moonshine/modeling_moonshine.py
+++ b/src/transformers/models/moonshine/modeling_moonshine.py
@@ -462,7 +462,7 @@ class MoonshinePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
# TODO arthur, how do we separate when it cross / self coming from different layer?
def _init_weights(self, module):
diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py
index 9706d99d7cd5..85c633502b65 100644
--- a/src/transformers/models/moonshine/modular_moonshine.py
+++ b/src/transformers/models/moonshine/modular_moonshine.py
@@ -497,7 +497,7 @@ class MoonshinePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
# TODO arthur, how do we separate when it cross / self coming from different layer?
def _init_weights(self, module):
diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py
index a9d0fd9781a7..6dbcb202f9a7 100644
--- a/src/transformers/models/mt5/modeling_mt5.py
+++ b/src/transformers/models/mt5/modeling_mt5.py
@@ -757,7 +757,7 @@ class MT5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_no_split_modules = ["MT5Block"]
_keep_in_fp32_modules = ["wo"]
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index df64eec95c31..5fffc6bb832f 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -589,7 +589,7 @@ class NemotronPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 817341cc258f..3eb83da1818a 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -294,7 +294,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": OlmoDecoderLayer,
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index b589364df59a..4292c235c32c 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -299,7 +299,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Olmo2DecoderLayer,
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index 8caff4d9abf1..38b538fdf301 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -706,7 +706,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 2fe9a146774c..ff5e8dfa010d 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -313,7 +313,7 @@ class OPTPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 581269653c31..2d82dccc1866 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -114,7 +114,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 05e34da4d2b3..a6531b76c6df 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -458,7 +458,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index ee48a02b04cf..f3be1886c1b9 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -758,7 +758,7 @@ class PegasusXPreTrainedModel(PreTrainedModel):
_supports_sdpa = False
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py
index 657aa805691b..e37d55d73342 100644
--- a/src/transformers/models/perception_lm/modeling_perception_lm.py
+++ b/src/transformers/models/perception_lm/modeling_perception_lm.py
@@ -95,7 +95,7 @@ class PerceptionLMPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 64f3bdd7b559..4f4155fa18c1 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -389,7 +389,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PersimmonDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_sdpa = True
_supports_flash_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index ea77b2d4719a..ab32157b7692 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -299,7 +299,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": PhiDecoderLayer,
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index c896def491e3..80d1f3c8f7b1 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -291,7 +291,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Phi3DecoderLayer,
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
index 855e8b7fc1b2..3bc45787e375 100644
--- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
@@ -1593,7 +1593,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Phi4MultimodalDecoderLayer,
diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py
index 0662acf7e698..3f27b4f456e7 100644
--- a/src/transformers/models/phimoe/modeling_phimoe.py
+++ b/src/transformers/models/phimoe/modeling_phimoe.py
@@ -890,7 +890,7 @@ class PhimoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 5d080e8f0c99..e39e374dfdd8 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -351,7 +351,7 @@ def forward(
class Pix2StructPreTrainedModel(PreTrainedModel):
config: Pix2StructConfig
- _supports_static_cache = False
+ _can_compile_fullgraph = False
@property
def dummy_inputs(self):
diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py
index 795dfb587421..3c571de1cd6b 100644
--- a/src/transformers/models/pop2piano/modeling_pop2piano.py
+++ b/src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -577,7 +577,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel):
is_parallelizable = False
supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_no_split_modules = ["Pop2PianoBlock"]
_keep_in_fp32_modules = ["wo"]
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index e8c5a1bc8a8a..4be3210863f8 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -263,7 +263,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Qwen2DecoderLayer,
diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 37c0c6da43de..feb6aad0c1d6 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -87,7 +87,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index d40b0a073c79..d9b995f02eb3 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1134,7 +1134,7 @@ def get_text_config(self, decoder=False):
class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel):
config: Qwen2_5OmniConfig
- _supports_static_cache = False
+ _can_compile_fullgraph = False
def _init_weights(self, module):
# important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 66fb7a7c06d5..063afe30d44b 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -326,7 +326,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index b14264f45516..f489b74df3d6 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -660,7 +660,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 73f063148089..87671aa01e82 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -289,7 +289,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Qwen3DecoderLayer,
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index f37568777b92..f67a0738a4f5 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -411,7 +411,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+ _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
_supports_attention_backend = True
_can_record_outputs = {
"router_logits": OutputRecorder(Qwen3MoeSparseMoeBlock, index=1),
diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py
index 92a3205a7b26..edc88c9b6ac4 100644
--- a/src/transformers/models/smollm3/modeling_smollm3.py
+++ b/src/transformers/models/smollm3/modeling_smollm3.py
@@ -293,7 +293,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": SmolLM3DecoderLayer,
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index b7a61f3360de..fe3624bc466b 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -621,7 +621,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.initializer_range
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 9e574d534980..edfb7210a03d 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -297,7 +297,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Starcoder2DecoderLayer,
diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index d71dffb98782..ab906d07612c 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -766,7 +766,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
base_model_prefix = "switch_transformers"
supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_no_split_modules = ["SwitchTransformersBlock"]
@property
diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index b5ff699f69a3..247437dee64a 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -771,7 +771,7 @@ class T5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"]
diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index c2fdbf5fc7d4..6c6b94bb1fcb 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -585,7 +585,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": T5GemmaDecoderLayer,
diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py
index 70a174474b7d..f29af9045589 100644
--- a/src/transformers/models/udop/modeling_udop.py
+++ b/src/transformers/models/udop/modeling_udop.py
@@ -255,7 +255,7 @@ class UdopPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
- _supports_static_cache = False
+ _can_compile_fullgraph = False
_keep_in_fp32_modules = ["wo"]
def _init_weights(self, module):
diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py
index 47b11acfd8dc..ef37ce2045e2 100644
--- a/src/transformers/models/umt5/modeling_umt5.py
+++ b/src/transformers/models/umt5/modeling_umt5.py
@@ -508,7 +508,7 @@ class UMT5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_no_split_modules = ["UMT5Block"]
_keep_in_fp32_modules = ["wo"]
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index aea725686822..befa350b907c 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -135,7 +135,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
def _init_weights(self, module):
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index d3c263807bea..f6a6c5156890 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -122,7 +122,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_flex_attn = True
_supports_attention_backend = True
diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py
index b2350310a8ae..ae949a457969 100644
--- a/src/transformers/models/voxtral/modeling_voxtral.py
+++ b/src/transformers/models/voxtral/modeling_voxtral.py
@@ -236,7 +236,7 @@ class VoxtralPreTrainedModel(PreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_attention_backend = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
# important: this ported version of Voxtral isn't meant for training from scratch - only
diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py
index fdb9862ad5f4..a3cb8c3ed00d 100644
--- a/src/transformers/models/voxtral/modular_voxtral.py
+++ b/src/transformers/models/voxtral/modular_voxtral.py
@@ -47,7 +47,7 @@ class VoxtralPreTrainedModel(Qwen2AudioPreTrainedModel):
_supports_flex_attn = True
_supports_cache_class = True
_supports_attention_backend = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
_supports_attention_backend = True
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 7b74c4c0b853..b367b82fcda4 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -553,7 +553,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flex_attn = True
- _supports_static_cache = True
+ _can_compile_fullgraph = True
def _init_weights(self, module):
std = self.config.init_std
diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py
index 11eb382bda99..cc48df55f5da 100644
--- a/src/transformers/utils/auto_docstring.py
+++ b/src/transformers/utils/auto_docstring.py
@@ -965,8 +965,9 @@ class ClassAttrs:
_supports_flex_attn = r"""
Whether the model's attention implementation supports FlexAttention.
"""
- _supports_static_cache = r"""
- Whether the model supports a `StaticCache` instance as `past_key_values`.
+ _can_compile_fullgraph = r"""
+ Whether the model can `torch.compile` fullgraph without graph breaks. Models will auto-compile if this flag is set to `True`
+ in inference, if a compilable cache is used.
"""
_supports_attention_backend = r"""
Whether the model supports attention interface functions. This flag signal that the model can be used as an efficient backend in TGI and vLLM.
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index fab1672b5c86..07614209b2ef 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1758,7 +1758,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
"""
for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
+ if not model_class._can_compile_fullgraph:
self.skipTest(reason="This model does not support the static cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -1978,7 +1978,7 @@ def test_generate_with_static_cache(self):
"""
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
+ if not model_class._can_compile_fullgraph:
self.skipTest(reason="This model does not support the static cache format")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
@@ -2081,8 +2081,8 @@ def test_generate_compile_model_forward(self):
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
# 1. Test exclusion criteria
- if not model_class._supports_static_cache:
- self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
+ if not model_class._can_compile_fullgraph:
+ self.skipTest("This model doesn't support compilation without graph breaks")
# 2. Prepares two sets of inputs
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4)
@@ -2195,8 +2195,8 @@ def test_generate_compilation_all_outputs(self):
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
"""
for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
- self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
+ if not model_class._can_compile_fullgraph:
+ self.skipTest("This model doesn't support compilation without graph breaks")
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 5589c8cc0d61..fc41a1e6fb47 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4400,7 +4400,7 @@ def test_custom_4d_attention_mask(self):
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
- if not model_class._supports_static_cache:
+ if not model_class._can_compile_fullgraph:
self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
set_config_for_less_flaky_test(config)