| |
| |
| |
| |
| @@ -15,6 +15,7 @@ |
| # limitations under the License. |
| import copy |
| import inspect |
| +import os |
| import warnings |
| from dataclasses import dataclass |
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union |
| @@ -3222,6 +3223,16 @@ def _sample( |
| unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) |
| model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) |
| |
| + def model_forward(model, *args, **kwargs):  # plain-function wrapper used for every decoding step after the first; swapped for a torch.compile'd version when past_key_values is a StaticCache |
| + return model.forward(*args, **kwargs)  # calls `forward` directly instead of `model(...)` — NOTE(review): confirm skipping the module's __call__ path is intended |
| + |
| + if isinstance(model_kwargs.get("past_key_values"), StaticCache): |
| + if self.device.type == "cuda": |
| + logger.warning_once("Using `torch.compile`.") |
| + os.environ["TOKENIZERS_PARALLELISM"] = "0" |
| + model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True) |
| + |
| + i = 0 |
| while self._has_unfinished_sequences( |
| this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length |
| ): |
| @@ -3232,8 +3243,11 @@ def _sample( |
| model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) |
| model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) |
| |
| - # forward pass to get next token |
| - outputs = self(**model_inputs, return_dict=True) |
| + if i == 0: |
| + outputs = self(**model_inputs, return_dict=True) |
| + i += 1 |
| + else: |
| + outputs = model_forward(self, return_dict=True, **model_inputs) |
| |
| # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping |
| model_kwargs = self._update_model_kwargs_for_generation( |
|
|