Harryx2025 committed on
Commit
9dad945
·
verified ·
1 Parent(s): 6c620ed

Update ts_generation_mixin.py

Browse files
Files changed (1) hide show
  1. ts_generation_mixin.py +13 -96
ts_generation_mixin.py CHANGED
@@ -1,30 +1,19 @@
1
- """
2
- Time Series Generation Mixin for PatchMoE
3
-
4
- This module provides generation capabilities specifically designed for time series
5
- forecasting tasks. It extends the standard Transformers GenerationMixin to handle
6
- time series data with proper input/output reshaping and autoregressive generation.
7
- """
8
-
9
- from typing import List, Optional, Union, Callable
10
  import torch
11
  from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
 
12
  from transformers.generation.utils import (
13
  GenerateNonBeamOutput,
 
 
14
  GenerationConfig,
15
  GenerateOutput,
16
  )
 
17
 
18
 
19
- class PatchMoEGenerationMixin(GenerationMixin):
20
- """
21
- Generation mixin class for PatchMoE time series forecasting.
22
-
23
- This class extends the standard Transformers GenerationMixin to provide
24
- specialized generation capabilities for time series data, including proper
25
- handling of multi-channel inputs and autoregressive forecasting.
26
- """
27
-
28
  @torch.no_grad()
29
  def generate(
30
  self,
@@ -43,48 +32,17 @@ class PatchMoEGenerationMixin(GenerationMixin):
43
  **kwargs,
44
  ) -> Union[GenerateOutput, torch.LongTensor]:
45
  """
46
- Generate time series forecasts using the PatchMoE model.
47
-
48
- This method handles the generation of time series forecasts with proper
49
- input preprocessing and output postprocessing for multi-channel data.
50
-
51
- Args:
52
- inputs (torch.Tensor): Input time series data of shape:
53
- - [batch_size, seq_len] for single-channel
54
- - [batch_size, seq_len, channels] for multi-channel
55
- generation_config (GenerationConfig, optional): Generation configuration
56
- logits_processor (LogitsProcessorList, optional): Logits processors
57
- stopping_criteria (StoppingCriteriaList, optional): Stopping criteria
58
- prefix_allowed_tokens_fn (Callable, optional): Prefix token function
59
- synced_gpus (bool, optional): Whether to sync GPUs
60
- assistant_model (PreTrainedModel, optional): Assistant model
61
- streamer (BaseStreamer, optional): Output streamer
62
- negative_prompt_ids (torch.Tensor, optional): Negative prompt IDs
63
- negative_prompt_attention_mask (torch.Tensor, optional): Negative attention mask
64
- revin (bool, optional): Whether to apply RevIN normalization
65
- num_samples (int, optional): Number of samples to generate
66
- **kwargs: Additional keyword arguments
67
-
68
- Returns:
69
- torch.Tensor: Generated forecasts of shape [batch_size, pred_len, channels]
70
-
71
- Raises:
72
- ValueError: If input shape is not supported
73
  """
74
- # Extract input dimensions
75
  batch_size = inputs.shape[0]
76
  length = inputs.shape[1]
77
  channel = 1
78
-
79
- # Handle multi-channel inputs
80
  if len(inputs.shape) == 3:
81
  channel = inputs.shape[2]
82
- # Reshape to [batch_size * channels, seq_len] for processing
83
  inputs = inputs.reshape(batch_size * channel, length)
84
  elif len(inputs.shape) > 3:
85
  raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
86
 
87
- # Call parent generation method
88
  outputs = super().generate(
89
  inputs=inputs,
90
  generation_config=generation_config,
@@ -99,8 +57,6 @@ class PatchMoEGenerationMixin(GenerationMixin):
99
  revin=revin,
100
  **kwargs,
101
  )
102
-
103
- # Reshape outputs back to [batch_size, pred_len, channels]
104
  pred_len = outputs.shape[1]
105
  outputs = outputs.reshape(batch_size, channel, pred_len)
106
  outputs = outputs.transpose(1, 2).contiguous()
@@ -123,50 +79,11 @@ class PatchMoEGenerationMixin(GenerationMixin):
123
  streamer: Optional["BaseStreamer"] = None,
124
  **model_kwargs,
125
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
126
- """
127
- Perform greedy search generation for time series forecasting.
128
-
129
- This method implements greedy decoding specifically for time series data,
130
- where the model generates forecasts autoregressively.
131
-
132
- Args:
133
- input_ids (torch.Tensor): Input time series data
134
- logits_processor (LogitsProcessorList, optional): Logits processors
135
- stopping_criteria (StoppingCriteriaList, optional): Stopping criteria
136
- max_length (int, optional): Maximum generation length
137
- pad_token_id (int, optional): Padding token ID (not used for time series)
138
- eos_token_id (int or List[int], optional): End-of-sequence token ID
139
- output_attentions (bool, optional): Whether to output attentions
140
- output_hidden_states (bool, optional): Whether to output hidden states
141
- output_scores (bool, optional): Whether to output scores
142
- output_logits (bool, optional): Whether to output logits
143
- return_dict_in_generate (bool, optional): Whether to return dict
144
- synced_gpus (bool): Whether to sync GPUs
145
- streamer (BaseStreamer, optional): Output streamer
146
- **model_kwargs: Additional model arguments
147
-
148
- Returns:
149
- torch.Tensor: Generated time series forecasts
150
- """
151
- # Move inputs to model device
152
  input_ids = input_ids.to(self.device)
153
  batch_size, cur_len = input_ids.shape
154
-
155
- # Initialize processors and criteria if not provided
156
- logits_processor = (
157
- logits_processor if logits_processor is not None else LogitsProcessorList()
158
- )
159
- stopping_criteria = (
160
- stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
161
- )
162
-
163
- # Prepare model inputs for generation
164
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
165
-
166
- # Generate forecasts with specified output length
167
- outputs = self(
168
- **model_inputs,
169
- return_dict=True,
170
- max_output_length=stopping_criteria.max_length - cur_len,
171
- )
172
- return outputs
 
1
+ import warnings
2
+ from typing import Any, Dict, List, Optional, Union, Callable
 
 
 
 
 
 
 
3
  import torch
4
  from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
5
+ from transformers.generation import validate_stopping_criteria, EosTokenCriteria
6
  from transformers.generation.utils import (
7
  GenerateNonBeamOutput,
8
+ GenerateEncoderDecoderOutput,
9
+ GenerateDecoderOnlyOutput,
10
  GenerationConfig,
11
  GenerateOutput,
12
  )
13
+ from transformers.utils import ModelOutput
14
 
15
 
16
+ class FalconTSTGenerationMixin(GenerationMixin):
 
 
 
 
 
 
 
 
17
  @torch.no_grad()
18
  def generate(
19
  self,
 
32
  **kwargs,
33
  ) -> Union[GenerateOutput, torch.LongTensor]:
34
  """
35
+ FalconTST generate function.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  """
 
37
  batch_size = inputs.shape[0]
38
  length = inputs.shape[1]
39
  channel = 1
 
 
40
  if len(inputs.shape) == 3:
41
  channel = inputs.shape[2]
 
42
  inputs = inputs.reshape(batch_size * channel, length)
43
  elif len(inputs.shape) > 3:
44
  raise ValueError("Input shape must be [batch, seq_len, channel] or [batch, seq_len]")
45
 
 
46
  outputs = super().generate(
47
  inputs=inputs,
48
  generation_config=generation_config,
 
57
  revin=revin,
58
  **kwargs,
59
  )
 
 
60
  pred_len = outputs.shape[1]
61
  outputs = outputs.reshape(batch_size, channel, pred_len)
62
  outputs = outputs.transpose(1, 2).contiguous()
 
79
  streamer: Optional["BaseStreamer"] = None,
80
  **model_kwargs,
81
  ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  input_ids = input_ids.to(self.device)
83
  batch_size, cur_len = input_ids.shape
84
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
85
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
 
 
 
 
 
 
 
86
  model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
87
+ # stopping_criteria.max_length = input_len + pred_len
88
+ outputs = self(**model_inputs, return_dict=True, max_output_length=stopping_criteria.max_length-cur_len)
89
+ return outputs