from typing import Any, Dict, List, Optional, Union, Callable

import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerationConfig, GenerateOutput
from transformers.utils import ModelOutput


class TSGenerationMixin(GenerationMixin):
    """Time-series generation mixin.

    Overrides ``GenerationMixin.generate`` with a single-shot forecasting
    pass: the input series is (optionally) instance-normalized (RevIN),
    run through the model's forward once, and de-normalized back to the
    original scale. Token-by-token autoregressive decoding is not used,
    so ``_update_model_kwargs_for_generation`` is a no-op.
    """

    @torch.no_grad()
    def generate(self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        patch_len: Optional[int] = None,
        stride_len: Optional[int] = None,
        max_output_length: Optional[int] = None,
        inference_patch_len: Optional[int] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        """Forecast ``max_output_length`` steps from the input series.

        Args:
            inputs: Required tensor of shape ``[batch_size, seq_len, n_vars]``.
            revin: If True (default), apply per-instance standardization
                before the forward pass and invert it on the output.
            patch_len: Patch length forwarded to the model.
            stride_len: Patch stride forwarded to the model.
            max_output_length: Forecast horizon (``target_dim``).
            inference_patch_len: Accepted for API compatibility; currently unused.

        Note:
            The text-generation parameters (``generation_config``,
            ``logits_processor``, ``stopping_criteria``, ...) are accepted
            only to match the ``GenerationMixin.generate`` signature and
            are ignored here.

        Returns:
            Tensor of shape ``[batch_size, max_output_length, n_vars]``.

        Raises:
            ValueError: If ``inputs`` is ``None`` or not 3-dimensional.
        """
        # Fail early with a clear message instead of an AttributeError on
        # `inputs.shape` when the caller omits the series.
        if inputs is None:
            raise ValueError('`inputs` must be provided as a tensor of shape '
                             '[batch_size, seq_len, n_vars]')
        if len(inputs.shape) != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # RevIN: standardize each series instance over the time axis.
            # Shapes are [batch_size, 1, n_vars] so they broadcast against
            # both the input and the forecast below. The 1e-5 epsilon
            # guards against division by zero on constant series.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        model_inputs = {
            "input": inputs,
            "patch_len": patch_len,
            "stride": stride_len,
            "target_dim": max_output_length,
        }

        # Single forward pass; the model returns a dict-like output whose
        # "prediction" entry has shape [batch_size, target_dim, n_vars].
        outputs = self(**model_inputs)
        outputs = outputs["prediction"]

        if revin:
            # Invert the standardization so forecasts are on the original scale.
            outputs = (outputs * stdev) + means
        return outputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """No-op override: generation is single-shot, so there is no
        per-step model-kwargs state to update. Returns ``model_kwargs``
        unchanged."""
        return model_kwargs