diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3657550d0e7..2785ca058dca 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -15,6 +15,7 @@ # limitations under the License. import copy import inspect +import os import warnings from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -3222,6 +3223,16 @@ def _sample( unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + def model_forward(model, *args, **kwargs): + return model.forward(*args, **kwargs) + + if isinstance(model_kwargs.get("past_key_values"), StaticCache): + if self.device.type == "cuda": + logger.warning_once("Using `torch.compile`.") + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) + + i = 0 while self._has_unfinished_sequences( this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length ): @@ -3232,8 +3243,11 @@ def _sample( model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) - # forward pass to get next token - outputs = self(**model_inputs, return_dict=True) + if i == 0: + outputs = self(**model_inputs, return_dict=True) + i += 1 + else: + outputs = model_forward(self, return_dict=True, **model_inputs) # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping model_kwargs = self._update_model_kwargs_for_generation(