from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationConfig
from transformers.utils import ModelOutput

if TYPE_CHECKING:
    # Imported only for type checking; "PreTrainedModel" and "BaseStreamer"
    # appear as forward references in the generate() signature below.
    from transformers import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer


class TSGenerationMixin(GenerationMixin):

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        patch_len: Optional[int] = None,
        stride_len: Optional[int] = None,
        max_output_length: Optional[int] = None,
        inference_patch_len: Optional[int] = None,  # accepted but unused in this implementation
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        if inputs is None or inputs.dim() != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')
        if revin:
            # RevIN: standardize each series with its own per-variable
            # mean and (biased) standard deviation over the time axis.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev
        model_inputs = {
            "input": inputs,
            "patch_len": patch_len,
            "stride": stride_len,
            "target_dim": max_output_length,
        }
        # A single forward pass produces the whole forecast horizon.
        outputs = self(**model_inputs)  # [batch_size, target_dim, n_vars]
        outputs = outputs["prediction"]
        if revin:
            # De-normalize: map the forecasts back to the original scale.
            outputs = (outputs * stdev) + means
        return outputs
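
    # Note on the RevIN round-trip above (explanatory note, not in the
    # original file): with mu = inputs.mean(dim=1) and
    # sigma = inputs.std(dim=1, unbiased=False) + 1e-5 per variable, the
    # model consumes (inputs - mu) / sigma and its prediction y is mapped
    # back to the original scale as y * sigma + mu, so each forecast keeps
    # the level and spread of its input series.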

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # No step-by-step decoding state to carry over: generate() above
        # returns the full horizon from a single forward pass.
        return model_kwargs
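

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). A class
# mixing in TSGenerationMixin is assumed to be callable as
# self(input=..., patch_len=..., stride=..., target_dim=...) and to return a
# mapping with a "prediction" tensor of shape [batch_size, target_dim, n_vars].
# The stub model below is hypothetical: it "forecasts" by repeating the last
# observed value across the horizon.
if __name__ == "__main__":
    import torch.nn as nn

    class StubForecaster(nn.Module, TSGenerationMixin):
        def forward(self, input, patch_len=None, stride=None, target_dim=None):
            last = input[:, -1:, :]  # [batch_size, 1, n_vars]
            return {"prediction": last.expand(-1, target_dim, -1)}

    model = StubForecaster()
    history = torch.randn(8, 512, 7)  # [batch_size, seq_len, n_vars]
    forecast = model.generate(
        inputs=history,
        revin=True,
        patch_len=16,
        stride_len=16,
        max_output_length=96,  # forecast horizon
    )
    print(forecast.shape)  # torch.Size([8, 96, 7])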