| |
| |
| |
| |
| @@ -294,7 +294,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": MyNewModel2DecoderLayer, |
| |
| |
| |
| |
| @@ -94,7 +94,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -293,7 +293,7 @@ class SuperPreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": SuperDecoderLayer, |
| |
| |
| |
| |
| @@ -2042,7 +2042,7 @@ def _prepare_cache_for_generation( |
| ) |
| if generation_config.cache_implementation is not None: |
| if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: |
| - if generation_config.cache_implementation == "static" and not self._supports_static_cache: |
| + if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph: |
| raise ValueError( |
| "This model does not support `cache_implementation='static'`. Please check the following " |
| "issue: https://github.com/huggingface/transformers/issues/28981" |
| @@ -2198,7 +2198,8 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge |
| using_compilable_cache = ( |
| isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable |
| ) |
| - can_compile = valid_hardware and using_compilable_cache and self._supports_static_cache |
| + # TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile) |
| + can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph |
| |
| # Exception 1: Some quantization methods do not support compilation |
| if getattr(self, "hf_quantizer", None) is not None: |
| |
| |
| |
| |
| @@ -2062,8 +2062,7 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH |
| # Flex Attention support |
| _supports_flex_attn = False |
| |
| - # Has support `torch.compile(fullgraph=True)` |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| # A tensor parallel plan to be applied to the model when TP is enabled. For |
| # top-level models, this attribute is currently defined in respective model |
| |
| |
| |
| |
| @@ -317,7 +317,7 @@ class ArceePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": ArceeDecoderLayer, |
| |
| |
| |
| |
| @@ -664,7 +664,7 @@ class AriaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing) |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": AriaTextDecoderLayer, |
| |
| |
| |
| |
| @@ -1312,7 +1312,7 @@ def _init_weights(self, module): |
| class AriaPreTrainedModel(LlamaPreTrainedModel): |
| config: AriaConfig |
| base_model_prefix = "" |
| - _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing) |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -96,7 +96,7 @@ class AyaVisionPreTrainedModel(PreTrainedModel): |
| |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -90,7 +90,7 @@ def pixel_shuffle(self, image_features): # B, S, D |
| |
| |
| class AyaVisionPreTrainedModel(LlavaPreTrainedModel): |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| def _init_weights(self, module): |
| std = ( |
| |
| |
| |
| |
| @@ -493,7 +493,7 @@ class BartPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -1565,7 +1565,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_param_buffer_assignment = False |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -347,7 +347,7 @@ class BioGptPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| |
| |
| |
| |
| @@ -172,7 +172,7 @@ class BioGptPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| |
| |
| |
| |
| @@ -312,7 +312,7 @@ class BitNetPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": BitNetDecoderLayer, |
| |
| |
| |
| |
| @@ -458,7 +458,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -451,7 +451,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -1831,7 +1831,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): |
| config: Blip2Config |
| main_input_name = "pixel_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _keep_in_fp32_modules = ["query_tokens", "qformer"] |
| _supports_flash_attn = False # because self.qformer does not support FA2 |
| |
| |
| |
| |
| |
| @@ -434,7 +434,7 @@ class BloomPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["BloomBlock"] |
| _skip_keys_device_placement = "past_key_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -815,7 +815,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_param_buffer_assignment = False |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| @@ -287,7 +287,7 @@ class CodeGenPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["CodeGenBlock"] |
| _skip_keys_device_placement = "past_key_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -345,7 +345,7 @@ class CoherePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": CohereDecoderLayer, |
| |
| |
| |
| |
| @@ -322,7 +322,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Cohere2DecoderLayer, |
| |
| |
| |
| |
| @@ -371,7 +371,7 @@ class CsmPreTrainedModel(PreTrainedModel): |
| # does not because of Mimi codec model |
| # _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": CsmDecoderLayer, |
| |
| |
| |
| |
| @@ -134,7 +134,7 @@ class CsmPreTrainedModel(PreTrainedModel): |
| # does not because of Mimi codec model |
| # _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": CsmDecoderLayer, |
| |
| |
| |
| |
| @@ -810,7 +810,7 @@ class DbrxPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module: nn.Module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -453,7 +453,7 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): |
| is_parallelizable = True |
| supports_gradient_checkpointing = True |
| |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -459,7 +459,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": DeepseekV2DecoderLayer, |
| |
| |
| |
| |
| @@ -498,7 +498,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": DeepseekV3DecoderLayer, |
| |
| |
| |
| |
| @@ -67,7 +67,7 @@ class DiaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| main_input_name = "input_ids" |
| _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] |
| |
| |
| |
| |
| |
| @@ -62,7 +62,7 @@ class DiaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| main_input_name = "input_ids" |
| _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"] |
| |
| |
| |
| |
| |
| @@ -534,7 +534,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = False |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = False |
| _can_record_outputs = { |
| "hidden_states": DiffLlamaDecoderLayer, |
| |
| |
| |
| |
| @@ -494,7 +494,7 @@ class DogePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = False |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(DogeCDMoE, index=1), |
| |
| |
| |
| |
| @@ -564,7 +564,7 @@ def forward( |
| |
| class DogePreTrainedModel(LlamaPreTrainedModel): |
| _supports_flash_attn = False |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(DogeCDMoE, index=1), |
| "hidden_states": DogeDecoderLayer, |
| |
| |
| |
| |
| @@ -418,7 +418,7 @@ class Dots1PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Dots1DecoderLayer, |
| |
| |
| |
| |
| @@ -1098,7 +1098,7 @@ class Emu3PreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_param_buffer_assignment = False |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| @@ -1320,7 +1320,6 @@ def forward( |
| |
| class Emu3Model(Emu3PreTrainedModel): |
| _checkpoint_conversion_mapping = {"text_model.model": "text_model"} |
| - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| @@ -1463,7 +1462,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): |
| "^vqmodel": "model.vqmodel", |
| "^text_model.lm_head": "lm_head", |
| } |
| - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -902,7 +902,6 @@ def forward(**super_kwargs): |
| |
| class Emu3Model(Emu3PreTrainedModel): |
| _checkpoint_conversion_mapping = {"text_model.model": "text_model"} |
| - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| @@ -1045,7 +1044,6 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): |
| "^vqmodel": "model.vqmodel", |
| "^text_model.lm_head": "lm_head", |
| } |
| - _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -311,7 +311,7 @@ class Ernie4_5PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Ernie4_5DecoderLayer, |
| |
| |
| |
| |
| @@ -473,7 +473,7 @@ class Ernie4_5_MoEPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(Ernie4_5_MoESparseMoeBlock, index=1), |
| |
| |
| |
| |
| @@ -643,7 +643,7 @@ class FalconPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -314,7 +314,7 @@ class GemmaPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": GemmaDecoderLayer, |
| |
| |
| |
| |
| @@ -344,7 +344,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Gemma2DecoderLayer, |
| |
| |
| |
| |
| @@ -434,7 +434,7 @@ class Gemma3PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Gemma3DecoderLayer, |
| |
| |
| |
| |
| @@ -1490,7 +1490,7 @@ class Gemma3nPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Gemma3nTextDecoderLayer, |
| |
| |
| |
| |
| @@ -331,7 +331,7 @@ class GlmPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": GlmDecoderLayer, |
| |
| |
| |
| |
| @@ -335,7 +335,7 @@ class Glm4PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Glm4DecoderLayer, |
| |
| |
| |
| |
| @@ -403,7 +403,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Glm4MoeDecoderLayer, |
| |
| |
| |
| |
| @@ -310,7 +310,7 @@ class Glm4MoeDecoderLayer(DeepseekV3DecoderLayer): |
| |
| |
| class Glm4MoePreTrainedModel(DeepseekV3PreTrainedModel): |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| |
| class Glm4MoeModel(DeepseekV3Model): |
| |
| |
| |
| |
| @@ -407,7 +407,7 @@ class Glm4vPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -283,7 +283,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = False |
| _supports_sdpa = False |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = False |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -563,7 +563,7 @@ class GPT2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_attention_backend = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -477,7 +477,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["GPTNeoBlock"] |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn = True |
| - _supports_static_cache = False # TODO: needs a HybridCache |
| + _can_compile_fullgraph = False # TODO: needs a HybridCache |
| |
| def __init__(self, *inputs, **kwargs): |
| super().__init__(*inputs, **kwargs) |
| |
| |
| |
| |
| @@ -364,7 +364,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": GPTNeoXDecoderLayer, |
| |
| |
| |
| |
| @@ -48,7 +48,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["GPTNeoXJapaneseLayer"] |
| _skip_keys_device_placement = "past_key_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| |
| |
| |
| |
| @@ -472,7 +472,7 @@ class GPTJPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["GPTJBlock"] |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_param_buffer_assignment = False |
| |
| def __init__(self, *inputs, **kwargs): |
| |
| |
| |
| |
| @@ -309,7 +309,7 @@ class GranitePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": GraniteDecoderLayer, |
| |
| |
| |
| |
| @@ -592,7 +592,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| |
| |
| |
| |
| @@ -1208,7 +1208,7 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _is_stateful = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -510,7 +510,7 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| |
| |
| |
| |
| @@ -316,7 +316,7 @@ class HeliumPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": HeliumDecoderLayer, |
| |
| |
| |
| |
| @@ -880,7 +880,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| |
| _supports_flash_attn = True |
| - _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs |
| + _can_compile_fullgraph = False # IDEFICS cannot compile due to dynamic control flow when checking inputs |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -340,7 +340,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| _no_split_modules = [ |
| "InstructBlipQFormerEmbeddings", |
| @@ -1354,7 +1354,7 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati |
| config: InstructBlipConfig |
| main_input_name = "pixel_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 |
| |
| def __init__(self, config: InstructBlipConfig): |
| |
| |
| |
| |
| @@ -827,7 +827,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| _no_split_modules = [ |
| "InstructBlipVideoQFormerEmbeddings", |
| @@ -1360,7 +1360,7 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel |
| config: InstructBlipVideoConfig |
| main_input_name = "pixel_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8 |
| |
| def __init__(self, config: InstructBlipVideoConfig): |
| |
| |
| |
| |
| @@ -524,7 +524,7 @@ class InternVLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -63,7 +63,7 @@ class JanusPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_param_buffer_assignment = False |
| |
| def _init_weights(self, module): |
| @@ -1123,7 +1123,7 @@ def forward( |
| |
| class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, config: JanusConfig): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -390,7 +390,7 @@ class JanusPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_param_buffer_assignment = False |
| |
| def _init_weights(self, module): |
| @@ -982,7 +982,7 @@ def forward( |
| |
| class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def __init__(self, config: JanusConfig): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -542,7 +542,7 @@ class Lfm2PreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Lfm2DecoderLayer, |
| |
| |
| |
| |
| @@ -401,7 +401,7 @@ def forward( |
| |
| |
| class Lfm2PreTrainedModel(LlamaPreTrainedModel): |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -315,7 +315,7 @@ class LlamaPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": LlamaDecoderLayer, |
| |
| |
| |
| |
| @@ -437,7 +437,7 @@ class Llama4PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -121,7 +121,7 @@ class LlavaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -232,7 +232,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -173,7 +173,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -286,7 +286,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -1250,7 +1250,7 @@ class LongT5PreTrainedModel(PreTrainedModel): |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["LongT5Block"] |
| |
| - _supports_static_cache = False # TODO: @raushan more involved due to local/global attn |
| + _can_compile_fullgraph = False # TODO: @raushan more involved due to local/global attn |
| |
| @property |
| # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs |
| |
| |
| |
| |
| @@ -525,7 +525,7 @@ class M2M100PreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| |
| # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -467,7 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -492,7 +492,7 @@ class MBartPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -1376,7 +1376,7 @@ class MimiPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| |
| |
| |
| |
| @@ -588,8 +588,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - # Note: only supports MiniMaxCache |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1), |
| |
| |
| |
| |
| @@ -470,8 +470,7 @@ def forward( |
| |
| |
| class MiniMaxPreTrainedModel(MixtralPreTrainedModel): |
| - # Note: only supports MiniMaxCache |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(MiniMaxSparseMoeBlock, index=1), |
| "hidden_states": MiniMaxDecoderLayer, |
| |
| |
| |
| |
| @@ -260,7 +260,7 @@ class MistralPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": MistralDecoderLayer, |
| |
| |
| |
| |
| @@ -186,7 +186,7 @@ class Mistral3PreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -388,7 +388,7 @@ class MixtralPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), |
| |
| |
| |
| |
| @@ -277,7 +277,7 @@ class MixtralRotaryEmbedding(MistralRotaryEmbedding): |
| |
| |
| class MixtralPreTrainedModel(MistralPreTrainedModel): |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(MixtralSparseMoeBlock, index=1), |
| "hidden_states": MixtralDecoderLayer, |
| |
| |
| |
| |
| @@ -850,7 +850,7 @@ class MllamaPreTrainedModel(PreTrainedModel): |
| "MllamaSelfAttentionDecoderLayer", |
| ] |
| |
| - _supports_static_cache = False # static cache cannot have different shapes for each layer |
| + _can_compile_fullgraph = False # static cache cannot have different shapes for each layer |
| _supports_sdpa = True |
| _supports_flash_attn = True |
| _supports_flex_attn = True |
| @@ -1449,7 +1449,7 @@ def forward( |
| ) |
| class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): |
| config: MllamaTextConfig |
| - _supports_static_cache = True # only the LLM without cross attn can do compile |
| + _can_compile_fullgraph = True # only the LLM without cross attn can do compile |
| base_model_prefix = "language_model" |
| _tied_weights_keys = ["lm_head.weight"] |
| |
| |
| |
| |
| |
| @@ -224,7 +224,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = False |
| _supports_gradient_checkpointing = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": ModernBertDecoderLayer, |
| |
| |
| |
| |
| @@ -401,7 +401,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = False |
| _supports_gradient_checkpointing = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": ModernBertDecoderLayer, |
| |
| |
| |
| |
| @@ -462,7 +462,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| # TODO arthur, how do we separate when it cross / self coming from different layer? |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -497,7 +497,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| # TODO arthur, how do we separate when it cross / self coming from different layer? |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -757,7 +757,7 @@ class MT5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| is_parallelizable = True |
| supports_gradient_checkpointing = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| _no_split_modules = ["MT5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| |
| |
| |
| @@ -589,7 +589,7 @@ class NemotronPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -294,7 +294,7 @@ class OlmoPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": OlmoDecoderLayer, |
| |
| |
| |
| |
| @@ -299,7 +299,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Olmo2DecoderLayer, |
| |
| |
| |
| |
| @@ -706,7 +706,7 @@ class OlmoePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -313,7 +313,7 @@ class OPTPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -114,7 +114,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["PaliGemmaMultiModalProjector"] |
| _skip_keys_device_placement = "past_key_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| |
| |
| |
| @@ -458,7 +458,7 @@ class PegasusPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -758,7 +758,7 @@ class PegasusXPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = False |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -95,7 +95,7 @@ class PerceptionLMPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -389,7 +389,7 @@ class PersimmonPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["PersimmonDecoderLayer"] |
| _skip_keys_device_placement = "past_key_values" |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_sdpa = True |
| _supports_flash_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| @@ -299,7 +299,7 @@ class PhiPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": PhiDecoderLayer, |
| |
| |
| |
| |
| @@ -291,7 +291,7 @@ class Phi3PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Phi3DecoderLayer, |
| |
| |
| |
| |
| @@ -1593,7 +1593,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Phi4MultimodalDecoderLayer, |
| |
| |
| |
| |
| @@ -890,7 +890,7 @@ class PhimoePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -351,7 +351,7 @@ def forward( |
| class Pix2StructPreTrainedModel(PreTrainedModel): |
| config: Pix2StructConfig |
| |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| @property |
| def dummy_inputs(self): |
| |
| |
| |
| |
| @@ -577,7 +577,7 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): |
| is_parallelizable = False |
| supports_gradient_checkpointing = True |
| |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _no_split_modules = ["Pop2PianoBlock"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| |
| |
| |
| |
| @@ -263,7 +263,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Qwen2DecoderLayer, |
| |
| |
| |
| |
| @@ -87,7 +87,7 @@ class Qwen2_5OmniPreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -1134,7 +1134,7 @@ def get_text_config(self, decoder=False): |
| |
| class Qwen2_5OmniPreTrainedModel(Qwen2_5_VLPreTrainedModel): |
| config: Qwen2_5OmniConfig |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| |
| def _init_weights(self, module): |
| # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only |
| |
| |
| |
| |
| @@ -326,7 +326,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -660,7 +660,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -289,7 +289,7 @@ class Qwen3PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Qwen3DecoderLayer, |
| |
| |
| |
| |
| @@ -411,7 +411,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| - _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "router_logits": OutputRecorder(Qwen3MoeSparseMoeBlock, index=1), |
| |
| |
| |
| |
| @@ -293,7 +293,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": SmolLM3DecoderLayer, |
| |
| |
| |
| |
| @@ -621,7 +621,7 @@ class StableLmPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -297,7 +297,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": Starcoder2DecoderLayer, |
| |
| |
| |
| |
| @@ -766,7 +766,7 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): |
| base_model_prefix = "switch_transformers" |
| supports_gradient_checkpointing = True |
| |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _no_split_modules = ["SwitchTransformersBlock"] |
| |
| @property |
| |
| |
| |
| |
| @@ -771,7 +771,7 @@ class T5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| is_parallelizable = True |
| supports_gradient_checkpointing = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| _no_split_modules = ["T5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| |
| |
| |
| @@ -585,7 +585,7 @@ class T5GemmaPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| _can_record_outputs = { |
| "hidden_states": T5GemmaDecoderLayer, |
| |
| |
| |
| |
| @@ -255,7 +255,7 @@ class UdopPreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| supports_gradient_checkpointing = True |
| |
| - _supports_static_cache = False |
| + _can_compile_fullgraph = False |
| _keep_in_fp32_modules = ["wo"] |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -508,7 +508,7 @@ class UMT5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| supports_gradient_checkpointing = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _no_split_modules = ["UMT5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| |
| |
| |
| |
| @@ -135,7 +135,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -122,7 +122,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn = True |
| _supports_sdpa = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_flex_attn = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| @@ -236,7 +236,7 @@ class VoxtralPreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_attention_backend = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| # important: this ported version of Voxtral isn't meant for training from scratch - only |
| |
| |
| |
| |
| @@ -47,7 +47,7 @@ class VoxtralPreTrainedModel(Qwen2AudioPreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_attention_backend = True |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| _supports_attention_backend = True |
| |
| |
| |
| |
| |
| |
| @@ -553,7 +553,7 @@ class WhisperPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_flex_attn = True |
| |
| - _supports_static_cache = True |
| + _can_compile_fullgraph = True |
| |
| def _init_weights(self, module): |
| std = self.config.init_std |
| |
| |
| |
| |
| @@ -965,8 +965,9 @@ class ClassAttrs: |
| _supports_flex_attn = r""" |
| Whether the model's attention implementation supports FlexAttention. |
| """ |
| - _supports_static_cache = r""" |
| - Whether the model supports a `StaticCache` instance as `past_key_values`. |
| + _can_compile_fullgraph = r""" |
| + Whether the model can `torch.compile` fullgraph without graph breaks. Models will auto-compile if this flag is set to `True` |
| + in inference, if a compilable cache is used. |
| """ |
| _supports_attention_backend = r""" |
| Whether the model supports attention interface functions. This flag signal that the model can be used as an efficient backend in TGI and vLLM. |
| |
| |
| |
| |
| @@ -1758,7 +1758,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self): |
| to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache. |
| """ |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_static_cache: |
| + if not model_class._can_compile_fullgraph: |
| self.skipTest(reason="This model does not support the static cache format") |
| |
| config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
| @@ -1978,7 +1978,7 @@ def test_generate_with_static_cache(self): |
| """ |
| set_model_tester_for_less_flaky_test(self) |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_static_cache: |
| + if not model_class._can_compile_fullgraph: |
| self.skipTest(reason="This model does not support the static cache format") |
| |
| config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
| @@ -2081,8 +2081,8 @@ def test_generate_compile_model_forward(self): |
| set_model_tester_for_less_flaky_test(self) |
| for model_class in self.all_generative_model_classes: |
| # 1. Test exclusion criteria |
| - if not model_class._supports_static_cache: |
| - self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
| + if not model_class._can_compile_fullgraph: |
| + self.skipTest("This model doesn't support compilation without graph breaks") |
| |
| # 2. Prepares two sets of inputs |
| config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4) |
| @@ -2195,8 +2195,8 @@ def test_generate_compilation_all_outputs(self): |
| In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered. |
| """ |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_static_cache: |
| - self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
| + if not model_class._can_compile_fullgraph: |
| + self.skipTest("This model doesn't support compilation without graph breaks") |
| |
| config, inputs_dict = self.prepare_config_and_inputs_for_generate() |
| if self.has_attentions: |
| |
| |
| |
| |
| @@ -4400,7 +4400,7 @@ def test_custom_4d_attention_mask(self): |
| set_model_tester_for_less_flaky_test(self) |
| |
| for model_class in self.all_generative_model_classes: |
| - if not model_class._supports_static_cache: |
| + if not model_class._can_compile_fullgraph: |
| self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") |
| config, _ = self.model_tester.prepare_config_and_inputs_for_common() |
| set_config_for_less_flaky_test(config) |
|
|