Delete ts_generation_mixin.py
Browse files- ts_generation_mixin.py +0 -72
ts_generation_mixin.py
DELETED
|
@@ -1,72 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
from typing import Any, Dict, List, Optional, Union, Callable
|
| 3 |
-
import torch
|
| 4 |
-
from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
|
| 5 |
-
from transformers.generation.utils import GenerationConfig, GenerateOutput
|
| 6 |
-
from transformers.utils import ModelOutput
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
class TSGenerationMixin(GenerationMixin):
|
| 10 |
-
@torch.no_grad()
|
| 11 |
-
def generate(self,
|
| 12 |
-
inputs: Optional[torch.Tensor] = None,
|
| 13 |
-
generation_config: Optional[GenerationConfig] = None,
|
| 14 |
-
logits_processor: Optional[LogitsProcessorList] = None,
|
| 15 |
-
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 16 |
-
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
| 17 |
-
synced_gpus: Optional[bool] = None,
|
| 18 |
-
assistant_model: Optional["PreTrainedModel"] = None,
|
| 19 |
-
streamer: Optional["BaseStreamer"] = None,
|
| 20 |
-
negative_prompt_ids: Optional[torch.Tensor] = None,
|
| 21 |
-
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 22 |
-
revin: Optional[bool] = True,
|
| 23 |
-
patch_len:Optional[int] = 48,
|
| 24 |
-
stride_len:Optional[int]= 48,
|
| 25 |
-
max_output_length:Optional[int] = 96,
|
| 26 |
-
inference_patch_len: Optional[int] = 48,
|
| 27 |
-
|
| 28 |
-
**kwargs,
|
| 29 |
-
) -> Union[GenerateOutput, torch.Tensor]:
|
| 30 |
-
if len(inputs.shape) != 3:
|
| 31 |
-
raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')
|
| 32 |
-
|
| 33 |
-
if revin:
|
| 34 |
-
means = inputs.mean(dim=1, keepdim=True)
|
| 35 |
-
stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
|
| 36 |
-
inputs = (inputs - means) / stdev
|
| 37 |
-
|
| 38 |
-
batch_size,seq_len,n_vars = inputs.shape
|
| 39 |
-
num_patch = (max(seq_len, patch_len)-patch_len) // stride_len + 1
|
| 40 |
-
outputs = inputs.view(batch_size, num_patch, patch_len, n_vars)
|
| 41 |
-
outputs = outputs.transpose(2, 3)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
model_inputs = {
|
| 45 |
-
"input" : outputs,
|
| 46 |
-
}
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
outputs = self(**model_inputs) #[batch_size,target_dim,n_vars]
|
| 50 |
-
|
| 51 |
-
outputs = outputs["prediction"]
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
if revin:
|
| 55 |
-
|
| 56 |
-
outputs = (outputs * stdev) + means
|
| 57 |
-
|
| 58 |
-
return outputs
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def _update_model_kwargs_for_generation(
|
| 62 |
-
self,
|
| 63 |
-
outputs: ModelOutput,
|
| 64 |
-
model_kwargs: Dict[str, Any],
|
| 65 |
-
horizon_length: int = 1,
|
| 66 |
-
is_encoder_decoder: bool = False,
|
| 67 |
-
standardize_cache_format: bool = False,
|
| 68 |
-
) -> Dict[str, Any]:
|
| 69 |
-
|
| 70 |
-
return model_kwargs
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|