pchen182224 committed on
Commit
e44570d
·
verified ·
1 Parent(s): 8bfb2d8

Delete ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +0 -72
ts_generation_mixin.py DELETED
@@ -1,72 +0,0 @@
1
-
2
- from typing import Any, Dict, List, Optional, Union, Callable
3
- import torch
4
- from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
5
- from transformers.generation.utils import GenerationConfig, GenerateOutput
6
- from transformers.utils import ModelOutput
7
-
8
-
9
class TSGenerationMixin(GenerationMixin):
    """Generation mixin for time-series forecasting models.

    Overrides ``GenerationMixin.generate`` with a single-shot forecasting
    path: (optionally) instance-normalize the input series (RevIN),
    patchify it, run one forward pass of the model, and de-normalize the
    prediction.  Most of the inherited ``generate`` parameters are accepted
    for signature compatibility but are unused here.
    """

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 negative_prompt_ids: Optional[torch.Tensor] = None,
                 negative_prompt_attention_mask: Optional[torch.Tensor] = None,
                 revin: Optional[bool] = True,
                 patch_len: Optional[int] = 48,
                 stride_len: Optional[int] = 48,
                 max_output_length: Optional[int] = 96,
                 inference_patch_len: Optional[int] = 48,
                 **kwargs,
                 ) -> Union[GenerateOutput, torch.Tensor]:
        """Forecast from a batch of multivariate series.

        Args:
            inputs: Tensor of shape ``[batch_size, seq_len, n_vars]``.
            revin: If True, standardize each series per-variable over the
                time axis before the forward pass and invert afterwards.
            patch_len: Length of each time patch fed to the model.
            stride_len: Step between consecutive patch starts.  May differ
                from ``patch_len`` (overlapping or gapped patches).
            max_output_length, inference_patch_len: Accepted for interface
                compatibility; not used in this single-shot path.
            Remaining parameters mirror ``GenerationMixin.generate`` and
            are ignored here.

        Returns:
            The model's ``"prediction"`` output, de-normalized when
            ``revin`` is True.  Shape is whatever the underlying model
            emits — presumably ``[batch_size, target_dim, n_vars]`` per
            the original inline comment; confirm against the model.

        Raises:
            ValueError: If ``inputs`` is missing, is not 3-D, or is
                shorter than one patch.
        """
        # Validate up front so a None/2-D input raises a clear ValueError
        # instead of an AttributeError / opaque reshape failure.
        if inputs is None or inputs.dim() != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # RevIN-style instance normalization; eps guards zero variance.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        batch_size, seq_len, n_vars = inputs.shape
        if seq_len < patch_len:
            raise ValueError(
                f'seq_len ({seq_len}) must be >= patch_len ({patch_len})')

        # Patchify with unfold: yields [batch_size, num_patch, n_vars,
        # patch_len], identical to the contiguous view+transpose layout
        # when stride_len == patch_len, and additionally correct for
        # stride_len != patch_len (the old fixed-size view was not).
        # Timesteps past the last full patch are dropped.
        patches = inputs.unfold(dimension=1, size=patch_len, step=stride_len)

        model_inputs = {
            "input": patches,
        }

        # Single forward pass; the model is expected to return a mapping
        # with a "prediction" entry.
        outputs = self(**model_inputs)
        prediction = outputs["prediction"]

        if revin:
            # Undo the instance normalization (broadcasts over the time dim).
            prediction = (prediction * stdev) + means

        return prediction

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Intentional no-op override: single-shot generation needs no
        per-step kwargs bookkeeping, so the kwargs pass through unchanged."""
        return model_kwargs