File size: 2,348 Bytes
c882c3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

from typing import Any, Dict, List, Optional, Union, Callable
import torch
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerationConfig, GenerateOutput
from transformers.utils import ModelOutput


class TSGenerationMixin(GenerationMixin):
    """Time-series replacement for the Hugging Face ``generate`` entry point.

    ``generate`` here does not run a decoding loop. It makes a single forward
    call to the model and returns the model's ``"prediction"`` output. The call
    is optionally wrapped in a RevIN-style per-series normalization: inputs are
    standardized before the call and the prediction is mapped back afterwards.
    """

    @torch.no_grad()
    def generate(self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: Optional[bool] = True,
        patch_len: Optional[int] = None,
        stride_len: Optional[int] = None,
        max_output_length: Optional[int] = None,
        inference_patch_len: Optional[int] = None,

        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
        """Forecast a batch of multivariate series with one forward pass.

        Only ``inputs``, ``revin``, ``patch_len``, ``stride_len`` and
        ``max_output_length`` are used. The other Hugging Face generation
        arguments (``generation_config``, ``logits_processor``, ``streamer``,
        ...), ``inference_patch_len`` and ``**kwargs`` are accepted so the
        signature stays compatible with ``GenerationMixin.generate``, but they
        are ignored.

        Args:
            inputs: Tensor of shape ``[batch_size, seq_len, n_vars]``.
            revin: When truthy, each series is standardized along dim 1
                (time) with its mean and population std (+1e-5) before the
                forward pass, and the prediction is de-standardized with the
                same statistics.
            patch_len: Forwarded to the model as ``patch_len``.
            stride_len: Forwarded to the model as ``stride``.
            max_output_length: Forwarded to the model as ``target_dim``.

        Returns:
            The model output's ``"prediction"`` entry, de-normalized when
            ``revin`` is set. Its shape is determined by the model's forward;
            presumably ``[batch_size, target_dim, n_vars]``. TODO confirm.

        Raises:
            ValueError: If ``inputs`` is None or is not 3-dimensional.
        """
        # Without this guard the rank check below would fail with an
        # AttributeError on None instead of a clear ValueError.
        if inputs is None:
            raise ValueError(
                '`inputs` is required and must have shape [batch_size, seq_len, n_vars], got None'
            )
        if inputs.ndim != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # Per-series statistics over the time axis (dim 1). keepdim=True
            # keeps them broadcastable against both inputs and predictions.
            means = inputs.mean(dim=1, keepdim=True)
            # Population std; the epsilon avoids division by zero for
            # constant series.
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        # Map this method's argument names onto the model's forward keywords.
        model_inputs = {
            "input": inputs,
            "patch_len": patch_len,
            "stride": stride_len,
            "target_dim": max_output_length,
        }

        # NOTE(review): the prediction presumably has shape
        # [batch_size, target_dim, n_vars]. The de-normalization below relies
        # on it broadcasting against the [batch_size, 1, n_vars] statistics.
        outputs = self(**model_inputs)["prediction"]

        if revin:
            outputs = outputs * stdev + means

        return outputs


    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            horizon_length: int = 1,
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Return ``model_kwargs`` unchanged.

        This override keeps no per-step state: ``outputs``, ``horizon_length``,
        ``is_encoder_decoder`` and ``standardize_cache_format`` are ignored.
        ``generate`` in this class does not call it, because it performs a
        single forward pass rather than an iterative decoding loop.
        """
        return model_kwargs