| |
| |
| |
| |
| @@ -349,7 +349,7 @@ In case you are using Sink Cache, you have to crop your inputs to that maximum l |
| >>> user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."] |
| |
| >>> past_key_values = DynamicCache() |
| ->>> max_cache_length = past_key_values.get_max_length() |
| +>>> max_cache_length = past_key_values.get_max_cache_shape() |
| |
| >>> messages = [] |
| >>> for prompt in user_prompts: |
| |
| |
| |
| |
| @@ -29,6 +29,8 @@ class Cache(torch.nn.Module): |
| Base, abstract class for all caches. The actual data structure is specific to each subclass. |
| """ |
| |
| + is_compileable = False |
| + |
| def __init__(self): |
| super().__init__() |
| |
| @@ -1098,6 +1100,8 @@ class StaticCache(Cache): |
| ``` |
| """ |
| |
| + is_compileable = True |
| + |
| # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. |
| @deprecate_kwarg("layer_device_map", version="4.52.0") |
| def __init__( |
| @@ -1297,6 +1301,7 @@ class SlidingWindowCache(StaticCache): |
| """ |
| |
| is_sliding = True |
| + is_compileable = True |
| |
| # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. |
| def __init__( |
| @@ -1421,6 +1426,7 @@ def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): |
| super().__init__() |
| self.self_attention_cache = self_attention_cache |
| self.cross_attention_cache = cross_attention_cache |
| + self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) |
| |
| self.is_updated = {} |
| for layer_idx in range(len(cross_attention_cache.key_cache)): |
| @@ -1612,6 +1618,8 @@ class HybridCache(Cache): |
| ``` |
| """ |
| |
| + is_compileable = True |
| + |
| # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. |
| @deprecate_kwarg("layer_device_map", version="4.52.0") |
| def __init__( |
| @@ -1832,6 +1840,8 @@ class MambaCache: |
| ``` |
| """ |
| |
| + is_compileable = True |
| + |
| # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. |
| def __init__( |
| self, |
| @@ -1975,6 +1985,8 @@ class OffloadedStaticCache(StaticCache): |
| ``` |
| """ |
| |
| + is_compileable = True |
| + |
| @deprecate_kwarg("layer_device_map", version="4.52.0") |
| def __init__( |
| self, |
| |
| |
| |
| |
| @@ -1579,7 +1579,7 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProces |
| |
| |
| @dataclass |
| -class CompileConfig(object): |
| +class CompileConfig: |
| """ |
| Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`. |
| See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments. |
| @@ -1620,7 +1620,9 @@ class CompileConfig(object): |
| backend: Union[str, Callable] = "inductor" |
| mode: str = "reduce-overhead" |
| options: Optional[dict] = None |
| + # Used to flag our `generate` call to compile on e.g. CPU. Often not optimal, but useful for testing purposes. |
| + _compile_all_devices = None |
| |
| def to_dict(self) -> Dict[str, Any]: |
| """Serializes this instance to a Python dictionary.""" |
| - return copy.deepcopy(self.__dict__) |
| + return copy.deepcopy({key: value for key, value in self.__dict__.items() if key != "_compile_all_devices"}) |
| |
| |
| |
| |
| @@ -3177,9 +3177,11 @@ def _sample( |
| model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) |
| |
| model_forward = self.__call__ |
| - if isinstance(model_kwargs.get("past_key_values"), StaticCache): |
| - if self.device.type == "cuda": |
| - logger.warning_once("Using `torch.compile`.") |
| + if isinstance(model_kwargs.get("past_key_values"), Cache): |
| + is_compileable = model_kwargs["past_key_values"].is_compileable |
| + if is_compileable and ( |
| + self.device.type == "cuda" or generation_config.compile_config._compile_all_devices |
| + ): |
| os.environ["TOKENIZERS_PARALLELISM"] = "0" |
| model_forward = self.get_compiled_call(generation_config.compile_config) |
| |
| |
| |
| |
| |
| @@ -708,7 +708,7 @@ class AriaPreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) |
| _supports_attention_backend = False |
| |
| def _init_weights(self, module): |
| @@ -1561,6 +1561,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| logits_to_keep=logits_to_keep, |
| + cache_position=cache_position, |
| ) |
| |
| logits = outputs[0] |
| |
| |
| |
| |
| @@ -1223,6 +1223,7 @@ def _init_weights(self, module): |
| |
| |
| class AriaPreTrainedModel(LlamaPreTrainedModel): |
| + _supports_static_cache = False # MoE models don't work with torch.compile (dynamic slicing) |
| _supports_attention_backend = False |
| |
| def _init_weights(self, module): |
| @@ -1535,6 +1536,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| logits_to_keep=logits_to_keep, |
| + cache_position=cache_position, |
| ) |
| |
| logits = outputs[0] |
| |
| |
| |
| |
| @@ -833,6 +833,7 @@ class DbrxPreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module: nn.Module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -1802,6 +1802,7 @@ def forward( |
| |
| class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["text_model.lm_head.weight"] |
| + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -1113,6 +1113,7 @@ def forward(**super_kwargs): |
| |
| class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["text_model.lm_head.weight"] |
| + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable |
| |
| def __init__(self, config): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -52,7 +52,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): |
| _skip_keys_device_placement = "past_key_values" |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| +    _supports_static_cache = False  # TODO (fix me): compilation fails due to a stride error? |
| |
| def _init_weights(self, module): |
| """Initialize the weights""" |
| |
| |
| |
| |
| @@ -843,6 +843,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -917,6 +917,7 @@ class IdeficsPreTrainedModel(PreTrainedModel): |
| _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"] |
| _supports_sdpa = True |
| _supports_cache_class = True |
| + _supports_static_cache = False # IDEFICS cannot compile due to dynamic control flow when checking inputs |
| |
| def _init_weights(self, module): |
| # important: this ported version of Idefics isn't meant for training from scratch - only |
| |
| |
| |
| |
| @@ -485,7 +485,7 @@ class MixtralPreTrainedModel(PreTrainedModel): |
| _supports_flex_attn = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| _supports_attention_backend = True |
| |
| def _init_weights(self, module): |
| |
| |
| |
| |
| @@ -45,7 +45,9 @@ |
| MistralForSequenceClassification, |
| MistralForTokenClassification, |
| MistralModel, |
| + MistralPreTrainedModel, |
| MistralRMSNorm, |
| + MistralRotaryEmbedding, |
| ) |
| from .configuration_mixtral import MixtralConfig |
| |
| @@ -313,6 +315,14 @@ def forward( |
| return outputs |
| |
| |
| +class MixtralRotaryEmbedding(MistralRotaryEmbedding): |
| + pass |
| + |
| + |
| +class MixtralPreTrainedModel(MistralPreTrainedModel): |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| + |
| + |
| class MixtralModel(MistralModel): |
| def __init__(self, config: MixtralConfig): |
| super().__init__(config) |
| |
| |
| |
| |
| @@ -767,7 +767,7 @@ class OlmoePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -912,7 +912,7 @@ class PhimoePreTrainedModel(PreTrainedModel): |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_quantized_cache = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -332,7 +332,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -882,7 +882,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| - _supports_static_cache = True |
| + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` |
| |
| def _init_weights(self, module): |
| std = self.config.initializer_range |
| |
| |
| |
| |
| @@ -1978,52 +1978,82 @@ def test_generate_with_quant_cache(self): |
| model.generate(**generation_kwargs, **inputs_dict) |
| |
| @pytest.mark.generate |
| - @require_torch_accelerator |
| - @slow |
| def test_generate_compile_model_forward(self): |
| """ |
| - Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. Tests |
| - end-to-end compilation and forward pass compilation only. |
| + Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results. |
| ⚠️ Runs two sequential generations to ensure the cache doesn't get stuck after the first compiled run! ⚠️ |
| """ |
| for model_class in self.all_generative_model_classes: |
| if not model_class._supports_static_cache: |
| - self.skipTest("This model doesn't support static cache") |
| + self.skipTest("This model doesn't support static cache (= no expectations of compilation support)") |
| |
| - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=4) |
| |
| model = model_class(config).to(torch_device) |
| model.eval() # otherwise `self.training` is `True` -- this flag is used at attn mask creation time |
| |
| - input_ids = inputs_dict["input_ids"].to(torch_device) |
| + main_input = inputs_dict[model.main_input_name].to(torch_device) |
| # creates two sets of *different* inputs with the same shape |
| - half_batch_size = input_ids.shape[0] // 2 |
| - input_ids_sets = [input_ids[:half_batch_size, :], input_ids[half_batch_size : half_batch_size * 2, :]] |
| - self.assertTrue(input_ids_sets[0].shape == input_ids_sets[1].shape) |
| + half_batch_size = main_input.shape[0] // 2 |
| + input_1 = {} |
| + input_2 = {} |
| + for key, value in inputs_dict.items(): |
| + if isinstance(value, torch.Tensor): |
| + input_1[key] = value[:half_batch_size, :].to(torch_device) |
| + input_2[key] = value[half_batch_size : half_batch_size * 2, :].to(torch_device) |
| + else: |
| + input_1[key] = value |
| + input_2[key] = value |
| + model_input_sets = [input_1, input_2] |
| + self.assertTrue( |
| + model_input_sets[0][model.main_input_name].shape == model_input_sets[1][model.main_input_name].shape |
| + ) |
| + |
| + # compilation-specific setup |
| + torch.compiler.reset() # prevent cached compilation from being used in the test |
| + has_defined_cache_implementation = model.generation_config.cache_implementation is not None |
| + model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) |
| |
| generation_kwargs = { |
| "do_sample": False, |
| - "max_new_tokens": 10, |
| + "max_new_tokens": 5, |
| "return_dict_in_generate": True, |
| "output_scores": True, |
| - "cache_implementation": "static", |
| } |
| |
| # get eager + dynamic cache results for future comparison |
| dynamic_outputs = [] |
| - for model_inputs in input_ids_sets: |
| - dynamic_outputs.append(model.generate(model_inputs, **generation_kwargs)) |
| - |
| - # get compiled results |
| - generation_config = copy.deepcopy(model.generation_config) |
| - generation_config.update(**generation_kwargs) |
| - torch.compiler.reset() |
| + for model_inputs in model_input_sets: |
| + gen_out = model.generate(**model_inputs, **generation_kwargs) |
| + dynamic_outputs.append(gen_out) |
| + # sanity checks for the default cache implementation |
| + if not has_defined_cache_implementation: |
| + decoder_cache = ( |
| + gen_out.past_key_values.self_attention_cache |
| + if config.is_encoder_decoder |
| + else gen_out.past_key_values |
| + ) |
| + self.assertTrue(isinstance(decoder_cache, DynamicCache)) |
| + self.assertFalse(decoder_cache.is_compileable) |
| + self.assertFalse(hasattr(model, "_compiled_call")) # our auto compile should NOT have been called |
| |
| - model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") |
| + # get compiled results -- relies on the automatic compilation triggered by specific "cache_implementation" |
| + if not has_defined_cache_implementation: |
| + generation_kwargs["cache_implementation"] = "static" |
| |
| compiled_outputs = [] |
| - for model_inputs in input_ids_sets: |
| - compiled_outputs.append(model.generate(model_inputs, generation_config=generation_config)) |
| + for model_inputs in model_input_sets: |
| + gen_out = model.generate(**model_inputs, **generation_kwargs) |
| + compiled_outputs.append(gen_out) |
| + # sanity checks |
| + decoder_cache = ( |
| + gen_out.past_key_values.self_attention_cache |
| + if config.is_encoder_decoder |
| + else gen_out.past_key_values |
| + ) |
| + self.assertFalse(isinstance(decoder_cache, DynamicCache)) |
| + self.assertTrue(decoder_cache.is_compileable) |
| + self.assertTrue(hasattr(model, "_compiled_call")) # our auto compile should have been called |
| |
| for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs): |
| self._check_similar_generate_outputs(dynamic_result, compiled_result) |
| |
| |
| |
| |
| @@ -331,11 +331,6 @@ def test_model_rope_scaling(self, scaling_type): |
| def test_batching_equivalence(self): |
| pass |
| |
| - # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow |
| - @unittest.skip("Chameleon is not compatible with end-to-end generation compilation") |
| - def test_generate_compile_model_forward(self): |
| - pass |
| - |
| |
| @require_torch |
| class ChameleonIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -368,10 +368,6 @@ def test_disk_offload_safetensors(self): |
| def test_disk_offload_bin(self): |
| pass |
| |
| - @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.") |
| - def test_generate_compile_model_forward(self): |
| - pass |
| - |
| |
| @require_torch |
| class DbrxModelIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -780,10 +780,6 @@ def test_contrastive_generate_low_memory(self): |
| def test_custom_4d_attention_mask(self): |
| pass |
| |
| - @unittest.skip(reason="IDEFICS cannot compile due to dynamic control flow when checking inputs") |
| - def test_generate_compile_model_forward(self): |
| - pass |
| - |
| @unittest.skip(reason="We only test the model that takes in multiple images") |
| def test_model(self): |
| pass |
| |
| |
| |
| |
| @@ -332,10 +332,6 @@ def test_beam_search_low_memory(self): |
| def test_generate_from_inputs_embeds_with_static_cache(self): |
| pass |
| |
| - @unittest.skip(reason="Can't compile fullgraph due to dynamic control flow in `prepare_inputs_for_generate`") |
| - def test_generate_compile_model_forward(self): |
| - pass |
| - |
| |
| @require_torch |
| class Qwen2VLIntegrationTest(unittest.TestCase): |
| |
| |
| |
| |
| @@ -1602,6 +1602,11 @@ def test_labels_sequence_max_length_error_after_changing_config(self): |
| with self.assertRaises(ValueError): |
| model(input_features=input_features, labels=labels) |
| |
| + # TODO (joao, eustache): fix me :) |
| + @unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types") |
| + def test_generate_compile_model_forward(self): |
| + pass |
| + |
| |
| @require_torch |
| @require_torchaudio |
| |
| |
| |
| |
| @@ -364,7 +364,7 @@ def test_sink_cache_iterative_prompts(self): |
| input_ids = gen_out |
| |
| # We went well beyond the cache length |
| - self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) |
| + self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5) |
| |
| # And it still produces a coherent english |
| decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) |
|
|