from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationConfig
from transformers.utils import ModelOutput


class TSGenerationMixin(GenerationMixin):
    """Generation mixin for time-series models.

    Overrides `GenerationMixin.generate` to produce a forecast in a single
    forward pass over a patched input series, with optional reversible
    instance normalization (RevIN) applied before the pass and undone after.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        patch_len: Optional[int] = None,
        stride_len: Optional[int] = None,
        max_output_length: Optional[int] = None,
        inference_patch_len: Optional[int] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        if inputs is None or inputs.ndim != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # RevIN: instance-wise normalization; keep the statistics so the
            # forecast can be de-normalized after the forward pass.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        model_inputs = {
            'input': inputs,
            'patch_len': patch_len,
            'stride': stride_len,
            'target_dim': max_output_length,
        }

        # Single forward pass; the model is expected to return a mapping with
        # a 'prediction' tensor of shape [batch_size, target_dim, n_vars].
        outputs = self(**model_inputs)
        outputs = outputs['prediction']

        if revin:
            # Undo the instance normalization on the forecast.
            outputs = (outputs * stdev) + means

        return outputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # The forecast is produced in a single forward pass, so there is no
        # autoregressive state to carry over; return the kwargs unchanged.
        return model_kwargs
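

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `_ToyModel` is a hypothetical
# stand-in, not part of the library). A real forecaster would mix in
# TSGenerationMixin the same way; here the model ignores the patching
# arguments and simply repeats the last observed value, but it returns the
# {'prediction': ...} mapping that `generate` expects, so the RevIN
# round-trip can be exercised end to end.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class _ToyModel(nn.Module, TSGenerationMixin):
        # Keyword names match the `model_inputs` dict built in `generate`.
        def forward(self, input, patch_len=None, stride=None, target_dim=None):
            last = input[:, -1:, :]  # naive forecast: [batch_size, 1, n_vars]
            return {'prediction': last.expand(-1, target_dim, -1)}

    model = _ToyModel()
    history = torch.randn(8, 512, 7)  # [batch_size, seq_len, n_vars]
    forecast = model.generate(
        history, revin=True, patch_len=32, stride_len=32, max_output_length=96
    )
    print(forecast.shape)  # torch.Size([8, 96, 7])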