pchen182224 committed on
Commit
c63b8cc
·
verified ·
1 Parent(s): 417981b

Upload 5 files

Browse files
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "act": "gelu",
3
+ "attn_dropout": 0.4,
4
+ "c_in": 1,
5
+ "context_points": 528,
6
+ "d_ff": 512,
7
+ "d_layers": 3,
8
+ "d_model": 256,
9
+ "dropout": 0.0,
10
+ "e_layers": 3,
11
+ "head_dropout": 0,
12
+ "head_type": "prediction",
13
+ "initializer_range": 0.02,
14
+ "mask_mode": "patch",
15
+ "mask_nums": 3,
16
+ "model_type": "LightGTS",
17
+ "n_heads": 16,
18
+ "num_patch": 11,
19
+ "patch_len": 48,
20
+ "shared_embedding": true,
21
+ "stride": 48,
22
+ "target_dim": 192,
23
+ "transformers_version": "4.30.2"
24
+ }
configuration_LightGTS.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import Optional
3
+ import math
4
+
5
+
6
class LightGTSConfig(PretrainedConfig):
    """Configuration for the LightGTS time-series model.

    Holds tokenizer (patching), encoder/decoder and head hyper-parameters.
    ``num_patch`` is derived from ``context_points``, ``patch_len`` and
    ``stride`` rather than supplied directly.

    Note: ``norm``, ``res_attention``, ``pre_norm``, ``store_attn``, ``pe``,
    ``learn_pe``, ``individual``, ``y_range`` and ``verbose`` are accepted for
    signature compatibility but are intentionally not stored on the config
    (storing them would change the serialized config.json).
    """

    model_type = "LightGTS"

    def __init__(self, context_points: int = 512, c_in: int = 1, target_dim: int = 96,
                 patch_len: int = 32, stride: int = 32, mask_mode: str = 'patch', mask_nums: int = 3,
                 e_layers: int = 3, d_layers: int = 3, d_model=256, n_heads=16,
                 shared_embedding=True, d_ff: int = 512,
                 norm: str = 'BatchNorm', attn_dropout: float = 0.4, dropout: float = 0., act: str = "gelu",
                 res_attention: bool = True, pre_norm: bool = False, store_attn: bool = False,
                 pe: str = 'sincos', learn_pe: bool = False, head_dropout=0,
                 head_type="prediction", individual=False,
                 y_range: Optional[tuple] = None, verbose: bool = False,
                 initializer_range: float = 0.02, **kwargs):
        self.context_points = context_points
        self.c_in = c_in
        self.target_dim = target_dim
        self.patch_len = patch_len
        self.stride = stride
        # Number of patches obtained when sliding a window of ``patch_len``
        # with step ``stride`` over ``context_points`` samples.
        self.num_patch = (max(self.context_points, self.patch_len) - self.patch_len) // self.stride + 1
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.e_layers = e_layers
        self.d_layers = d_layers
        self.d_model = d_model
        self.n_heads = n_heads
        self.shared_embedding = shared_embedding
        self.d_ff = d_ff
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.head_dropout = head_dropout
        self.act = act
        self.head_type = head_type
        # Previously hard-coded to 0.02; now a parameter with the same default
        # so existing configs/checkpoints behave identically.
        self.initializer_range = initializer_range
        super().__init__(**kwargs)
modeling_LightGTS.py ADDED
@@ -0,0 +1,862 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from configuration_LightGTS import LightGTSConfig
3
+ from ts_generation_mixin import TSGenerationMixin
4
+ import torch
5
+ from torch import nn
6
+ from torch import Tensor
7
+ from typing import Callable, Optional
8
+ import math
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+
13
class LightGTSPreTrainedModel(PreTrainedModel):
    """Base class wiring LightGTS into the Hugging Face ``PreTrainedModel``
    machinery (weight init, save/load, device placement)."""

    # HF plumbing: ties this model family to its config class and declares
    # which transformers runtime features the architecture supports.
    config_class = LightGTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TSTEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights as N(0, initializer_range);
        zero Linear biases and the Embedding padding row."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
34
+
35
+
36
class LightGTSForPrediction(LightGTSPreTrainedModel, TSGenerationMixin):
    """HF wrapper around :class:`LightGTSForZeroShot` for forecasting.

    ``forward`` returns ``{"prediction": ..., "loss": ...}``; the loss is MSE
    against ``labels`` when they are provided.
    """

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        # NOTE(review): act='relu', shared_embedding=True, res_attention=False
        # and learn_pe=False are hard-coded here and override config.act etc.
        # Kept as-is to preserve checkpoint behavior — confirm intent.
        self.model = LightGTSForZeroShot(c_in=config.c_in,
                                         target_dim=config.target_dim,
                                         patch_len=config.patch_len,
                                         stride=config.stride,
                                         num_patch=config.num_patch,
                                         e_layers=config.e_layers,
                                         d_layers=config.d_layers,
                                         n_heads=config.n_heads,
                                         d_model=config.d_model,
                                         shared_embedding=True,
                                         d_ff=config.d_ff,
                                         dropout=config.dropout,
                                         attn_dropout=config.attn_dropout,
                                         head_dropout=config.head_dropout,
                                         act='relu',
                                         head_type=config.head_type,
                                         res_attention=False,
                                         learn_pe=False
                                         )

    def forward(self, input, labels=None):
        """Run the backbone; optionally compute an MSE training loss.

        Args:
            input: patched series tensor [bs x num_patch x n_vars x patch_len].
            labels: optional target tensor; reshaped to match the prediction.

        Returns:
            dict with keys "prediction" and "loss" (None when no labels).
        """
        outputs = self.model(input)

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            # BUG FIX: the original called self.loss_fn, which was never
            # defined anywhere in the class, so any call with labels raised
            # AttributeError. Use MSE loss directly instead.
            loss = F.mse_loss(outputs, labels)

        return {"prediction": outputs, "loss": loss}
77
+
78
class LightGTS(nn.Module):
    """
    Output dimension:
         [bs x target_dim x nvars] for prediction
         [bs x target_dim] for regression
         [bs x target_dim] for classification
         [bs x num_patch x n_vars x patch_len] for pretrain
    """
    def __init__(self, c_in:int, target_dim:int, patch_len:int, stride:int, num_patch:int, mask_mode:str = 'patch',mask_nums:int = 3,
                 e_layers:int=3, d_layers:int=3, d_model=128, n_heads=16, shared_embedding=True, d_ff:int=256,
                 norm:str='BatchNorm', attn_dropout:float=0.4, dropout:float=0., act:str="gelu",
                 res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='sincos', learn_pe:bool=False, head_dropout = 0,
                 head_type = "prediction", individual = False,
                 y_range:Optional[tuple]=None, verbose:bool=False, **kwargs):

        super().__init__()
        assert head_type in ['pretrain', 'prediction', 'regression', 'classification'], 'head type should be either pretrain, prediction, or regression'

        # Basic
        self.num_patch = num_patch
        self.target_dim=target_dim
        # Number of decoder output patches needed to cover the horizon.
        self.out_patch_num = math.ceil(target_dim / patch_len)
        # Canonical patch length the learned embedding was trained with.
        self.target_patch_len = 48

        # Embedding: maps a canonical 48-point patch to d_model.
        self.embedding = nn.Linear(self.target_patch_len, d_model)
        # Learnable CLS token prepended per variable.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model),requires_grad=True)

        # Encoder
        self.encoder = TSTEncoder(d_model, n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                   pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=e_layers,
                                   store_attn=store_attn)

        # Decoder
        self.decoder = Decoder(d_layers, patch_len=patch_len, d_model=d_model, n_heads=n_heads, d_ff=d_ff,attn_dropout= attn_dropout, dropout=dropout)

        # Head
        self.n_vars = c_in
        self.head_type = head_type
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.d_model = d_model
        self.patch_len = patch_len

        if head_type == "pretrain":
            self.head = PretrainHead(d_model, patch_len, head_dropout) # custom head passed as a partial func with all its kwargs
        elif head_type == "prediction":
            self.head = decoder_PredictHead(d_model, self.patch_len, self.target_patch_len, head_dropout)

    # def get_dynamic_weights(self, n_preds):
    #     """
    #     Generate dynamic weights for the replicated tokens. This example uses a linearly decreasing weight.
    #     You can modify this to use other schemes like exponential decay, sine/cosine, etc.
    #     """
    #     # Linearly decreasing weights from 1.0 to 0.5 (as an example)
    #     weights = torch.linspace(1.0, 0.5, n_preds)
    #     return weights

    def get_dynamic_weights(self, n_preds, decay_rate=0.5):
        """
        Generate dynamic weights for the replicated tokens using an exponential decay scheme.

        Args:
        - n_preds (int): Number of predictions to generate weights for.
        - decay_rate (float): The base of the exponential decay. Lower values decay faster (default: 0.5).

        Returns:
        - torch.Tensor: A tensor of weights with exponential decay.
        """
        # Exponential decay weights
        weights = decay_rate ** torch.arange(n_preds)
        return weights

    def decoder_predict(self, bs, n_vars, dec_cross):
        """
        dec_cross: tensor [bs x n_vars x num_patch x d_model]

        Seeds the decoder with the last encoder token replicated out_patch_num
        times, scaled by exponentially decaying weights, then cross-attends to
        the full encoder output.
        """
        # dec_in = self.decoder_embedding.e xpand(bs, self.n_vars, self.out_patch_num, -1)
        # dec_in = self.embedding(self.decoder_len).expand(bs, -1, -1, -1)
        # dec_in = self.decoder_embedding.expand(bs, n_vars, self.out_patch_num, -1)
        # dec_in = dec_cross.mean(2).unsqueeze(2).expand(-1,-1,self.out_patch_num,-1)
        dec_in = dec_cross[:,:,-1,:].unsqueeze(2).expand(-1,-1,self.out_patch_num,-1)
        # dec_in = torch.ones_like(dec_in)
        weights = self.get_dynamic_weights(self.out_patch_num).to(dec_in.device)
        dec_in = dec_in * weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        # dec_in = torch.cat((dec_in, self.sep_tokens), dim=2)

        # dec_in = dec_cross[:,:,-self.out_patch_num:,:]
        # dec_in = torch.ones([bs, n_vars, self.out_patch_num, self.d_model]).to(dec_cross.device)
        # dec_in = dec_in + self.pos[-self.out_patch_num:,:]
        decoder_output = self.decoder(dec_in, dec_cross)
        decoder_output = decoder_output.transpose(2,3)

        return decoder_output


    def forward(self, z):
        """
        z: tensor [bs x num_patch x n_vars x patch_len]
        """

        bs, num_patch, n_vars, patch_len = z.shape

        # tokenizer
        cls_tokens = self.cls_embedding.expand(bs, n_vars, -1, -1)

        # NOTE(review): a fresh nn.Linear is created on EVERY forward call and
        # lives on the default device — this will break if z is on GPU, and
        # resample_patchemb is defined elsewhere in this file (not visible
        # here). Confirm this is intended (flexible-patch-length inference).
        embedding = nn.Linear(patch_len, self.d_model, bias=False)
        embedding.weight.data = resample_patchemb(old=self.embedding.weight.data, new_patch_len=self.patch_len)

        z = embedding(z).permute(0,2,1,3)                                                # [bs x n_vars x num_patch x d_model]


        z = torch.cat((cls_tokens, z), dim=2)                                            # [bs x n_vars x (1 + num_patch) x d_model]

        # encoder
        z = torch.reshape(z, (-1, 1 + num_patch, self.d_model))                          # [bs*n_vars x num_patch x d_model]
        z = self.encoder(z)
        z = torch.reshape(z, (-1, n_vars, 1 + num_patch, self.d_model))                  # [bs, n_vars x num_patch x d_model]

        # decoder
        z = self.decoder_predict(bs, n_vars, z[:,:,:,:])

        # predict: head maps decoder patches back to series, then truncate to
        # the requested horizon (out_patch_num * patch_len may overshoot).
        z = self.head(z[:,:,:,:])
        z = z[:,:self.target_dim, :]


        # z: [bs x target_dim x nvars] for prediction
        #    [bs x target_dim] for regression
        #    [bs x target_dim] for classification
        #    [bs x num_patch x n_vars x patch_len] for pretrain
        return z
213
+
214
class LightGTSForZeroShot(nn.Module):
    """
    Zero-shot variant: input patches of arbitrary length are linearly resized
    to the canonical target_patch_len (48) before embedding, instead of
    resampling the embedding matrix as in LightGTS.

    Output dimension:
         [bs x target_dim x nvars] for prediction
         [bs x target_dim] for regression
         [bs x target_dim] for classification
         [bs x num_patch x n_vars x patch_len] for pretrain
    """
    def __init__(self, c_in:int, target_dim:int, patch_len:int, stride:int, num_patch:int, mask_mode:str = 'patch',mask_nums:int = 3,
                 e_layers:int=3, d_layers:int=3, d_model=128, n_heads=16, shared_embedding=True, d_ff:int=256,
                 norm:str='BatchNorm', attn_dropout:float=0.4, dropout:float=0., act:str="gelu",
                 res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='sincos', learn_pe:bool=False, head_dropout = 0,
                 head_type = "prediction", individual = False,
                 y_range:Optional[tuple]=None, verbose:bool=False, **kwargs):

        super().__init__()
        assert head_type in ['pretrain', 'prediction', 'regression', 'classification'], 'head type should be either pretrain, prediction, or regression'

        # Basic
        self.num_patch = num_patch
        self.target_dim=target_dim
        # Number of decoder output patches needed to cover the horizon.
        self.out_patch_num = math.ceil(target_dim / patch_len)
        # Canonical patch length all inputs are resized to before embedding.
        self.target_patch_len = 48
        # Embedding
        self.embedding = nn.Linear(self.target_patch_len, d_model)
        # self.decoder_embedding = nn.Parameter(torch.randn(1, 1,1, d_model),requires_grad=True)
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model),requires_grad=True)
        # self.sep_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model),requires_grad=True)

        # Position Embedding
        # self.pos = positional_encoding(pe, learn_pe, 1 + num_patch + self.out_patch_num, d_model)
        # self.drop_out = nn.Dropout(dropout)

        # Encoder
        self.encoder = TSTEncoder(d_model, n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                   pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=e_layers,
                                   store_attn=store_attn)

        # Decoder
        self.decoder = Decoder(d_layers, patch_len=patch_len, d_model=d_model, n_heads=n_heads, d_ff=d_ff,attn_dropout= attn_dropout, dropout=dropout)

        # Head
        self.n_vars = c_in
        self.head_type = head_type
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.d_model = d_model
        self.patch_len = patch_len




        if head_type == "pretrain":
            self.head = PretrainHead(d_model, patch_len, head_dropout) # custom head passed as a partial func with all its kwargs
        elif head_type == "prediction":
            self.head = decoder_PredictHead(d_model, self.patch_len, self.target_patch_len, head_dropout)

        # self.apply(self._init_weights)

    # def get_dynamic_weights(self, n_preds):
    #     """
    #     Generate dynamic weights for the replicated tokens. This example uses a linearly decreasing weight.
    #     You can modify this to use other schemes like exponential decay, sine/cosine, etc.
    #     """
    #     # Linearly decreasing weights from 1.0 to 0.5 (as an example)
    #     weights = torch.linspace(1.0, 0.5, n_preds)
    #     return weights

    def get_dynamic_weights(self, n_preds, decay_rate=0.5):
        """
        Generate dynamic weights for the replicated tokens using an exponential decay scheme.

        Args:
        - n_preds (int): Number of predictions to generate weights for.
        - decay_rate (float): The base of the exponential decay. Lower values decay faster (default: 0.5).

        Returns:
        - torch.Tensor: A tensor of weights with exponential decay.
        """
        # Exponential decay weights
        weights = decay_rate ** torch.arange(n_preds)
        return weights

    def decoder_predict(self, bs, n_vars, dec_cross):
        """
        dec_cross: tensor [bs x n_vars x num_patch x d_model]

        Seeds the decoder with the last encoder token replicated out_patch_num
        times, scaled by exponentially decaying weights, then cross-attends to
        the full encoder output.
        """
        # dec_in = self.decoder_embedding.expand(bs, self.n_vars, self.out_patch_num, -1)
        # dec_in = self.embedding(self.decoder_len).expand(bs, -1, -1, -1)
        # dec_in = self.decoder_embedding.expand(bs, n_vars, self.out_patch_num, -1)
        # dec_in = dec_cross.mean(2).unsqueeze(2).expand(-1,-1,self.out_patch_num,-1)
        dec_in = dec_cross[:,:,-1,:].unsqueeze(2).expand(-1,-1,self.out_patch_num,-1)
        weights = self.get_dynamic_weights(self.out_patch_num).to(dec_in.device)
        dec_in = dec_in * weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        # dec_in = torch.cat((dec_in, self.sep_tokens), dim=2)

        # dec_in = dec_cross[:,:,-self.out_patch_num:,:]
        # dec_in = torch.ones([bs, n_vars, self.out_patch_num, self.d_model]).to(dec_cross.device)
        # dec_in = dec_in + self.pos[-self.out_patch_num:,:]
        decoder_output = self.decoder(dec_in, dec_cross)
        decoder_output = decoder_output.transpose(2,3)

        return decoder_output


    def forward(self, z):
        """
        z: tensor [bs x num_patch x n_vars x patch_len]
        """
        bs, num_patch, n_vars, patch_len = z.shape
        # Resample every patch to the canonical length before embedding.
        z = resize(z, target_patch_len=self.target_patch_len)

        # tokenizer
        cls_tokens = self.cls_embedding.expand(bs, n_vars, -1, -1)
        z = self.embedding(z).permute(0,2,1,3)                                           # [bs x n_vars x num_patch x d_model]
        z = torch.cat((cls_tokens, z), dim=2)                                            # [bs x n_vars x (1 + num_patch) x d_model]
        # z = self.drop_out(z + self.pos[:1 + self.num_patch, :])

        # encoder
        z = torch.reshape(z, (-1, 1 + num_patch, self.d_model))                          # [bs*n_vars x num_patch x d_model]
        z = self.encoder(z)
        z = torch.reshape(z, (-1, n_vars, 1 + num_patch, self.d_model))                  # [bs, n_vars x num_patch x d_model]

        # decoder
        z = self.decoder_predict(bs, n_vars, z[:,:,:,:])

        # predict: head maps decoder patches back to series, then truncate to
        # the requested horizon (out_patch_num * patch_len may overshoot).
        z = self.head(z[:,:,:,:])
        z = z[:,:self.target_dim, :]


        # z: [bs x target_dim x nvars] for prediction
        #    [bs x target_dim] for regression
        #    [bs x target_dim] for classification
        #    [bs x num_patch x n_vars x patch_len] for pretrain
        return z
351
+
352
def resize(x, target_patch_len):
    """Linearly resample every patch to ``target_patch_len`` points.

    Args:
        x: tensor [bs x num_patch x n_vars x patch_len]
        target_patch_len: desired length of each patch.

    Returns:
        tensor [bs x num_patch x n_vars x target_patch_len]
    """
    batch, patches, channels, _ = x.shape
    # Fold batch and patch dims together so n_vars acts as the channel axis
    # expected by 1-D interpolation.
    flat = x.reshape(batch * patches, channels, -1)
    resampled = F.interpolate(flat, size=target_patch_len,
                              mode='linear', align_corners=False)
    return resampled.reshape(batch, patches, channels, target_patch_len)
360
+
361
class TSTEncoder(nn.Module):
    """A stack of ``n_layers`` identical TSTEncoderLayer blocks.

    When ``res_attention`` is enabled, each layer's pre-softmax attention
    scores are threaded into the next layer (Realformer-style residual
    attention).
    """

    def __init__(self, d_model, n_heads, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()
        self.layers = nn.ModuleList(
            TSTEncoderLayer(d_model, n_heads=n_heads, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
            for _ in range(n_layers)
        )
        self.res_attention = res_attention

    def forward(self, src: Tensor):
        """src: tensor [bs x q_len x d_model]"""
        out = src
        if self.res_attention:
            prev_scores = None
            for layer in self.layers:
                out, prev_scores = layer(out, prev=prev_scores)
            return out
        for layer in self.layers:
            out = layer(out)
        return out
385
+
386
class TSTEncoderLayer(nn.Module):
    """One post-norm (or optionally pre-norm) Transformer encoder layer:
    multi-head self-attention + position-wise feed-forward, each followed by
    residual add and BatchNorm/LayerNorm."""

    def __init__(self, d_model, n_heads, d_ff=256, store_attn=False,
                 norm='LayerNorm', attn_dropout=0, dropout=0., bias=True,
                 activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model%n_heads, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads
        d_v = d_model // n_heads

        # Multi-Head attention
        self.res_attention = res_attention
        self.self_attn = MultiheadAttention(d_model, n_heads, d_k, d_v, attn_dropout=attn_dropout, proj_dropout=dropout, res_attention=res_attention)

        # Add & Norm — BatchNorm operates over features, so the sequence and
        # feature dims are swapped around it via Transpose.
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))

        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

        # # se block
        # self.SE = SE_Block(inchannel=7)


    def forward(self, src:Tensor, prev:Optional[Tensor]=None):
        """
        src: tensor [bs x q_len x d_model]
        prev: optional pre-softmax attention scores from the previous layer
              (residual attention); only used when res_attention is True.
        """
        # Multi-Head attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        ## Multi-Head attention
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            # attention_mask = causal_attention_mask(src.shape[1]).to(src.device)
            # src2, attn = self.self_attn(src, src, src, attn_mask=attention_mask)
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            # Keep the last attention map for inspection/visualization.
            self.attn = attn

        # total, num_patch, d_model = src2.size()
        # bs = int(total/7)

        # src2 = self.SE(src2.reshape(bs, 7, num_patch, -1)).reshape(total, num_patch, -1)


        ## Add & Norm
        src = src + self.dropout_attn(src2) # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        ## Position-wise Feed-Forward
        src2 = self.ff(src)
        ## Add & Norm
        src = src + self.dropout_ffn(src2) # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        else:
            return src
468
+
469
+
470
class Decoder(nn.Module):
    """Sequential stack of ``d_layers`` DecoderLayer blocks; every layer
    cross-attends to the same encoder memory ``cross``."""

    def __init__(self, d_layers, patch_len, d_model, n_heads, d_ff=None, attn_dropout=0.2, dropout=0.1):
        super(Decoder, self).__init__()
        self.decoder_layers = nn.ModuleList(
            DecoderLayer(patch_len, d_model, n_heads, d_ff, attn_dropout, dropout)
            for _ in range(d_layers)
        )

    def forward(self, x, cross):
        """x: decoder input tokens; cross: encoder output (shared memory)."""
        out = x
        for block in self.decoder_layers:
            out = block(out, cross)
        return out
483
+
484
+
485
class DecoderLayer(nn.Module):
    """Decoder block: causal self-attention, cross-attention to the encoder
    memory (with RoPE via rope_type=True), then a pointwise MLP.

    NOTE(review): each sublayer applies norm to its output BEFORE the residual
    add (norm(f(x)) + x) — unusual vs. standard post-norm; presumably
    intentional, confirm against training code.
    """
    def __init__(self, patch_len, d_model, n_heads, d_ff=None, attn_dropout = 0.2, dropout=0.5, norm="BatchNorm"):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiheadAttention(d_model, n_heads, res_attention=False, attn_dropout=attn_dropout)
        # rope_type=True selects the decoder RoPE variant (positions offset
        # past the encoder memory length).
        self.cross_attention = MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, rope_type=True)
        # self.pos_embed = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1, groups=d_model)

        if 'batch' in norm.lower():
            self.norm1 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            self.norm2 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            self.norm3 = nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
        else:
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.norm3 = nn.LayerNorm(d_model)


        self.dropout = nn.Dropout(dropout)

        self.MLP1 = CMlp(in_features = d_model, hidden_features = d_ff, out_features = d_model, drop=dropout)



    def forward(self, x, cross):
        """
        x:     [batch x n_vars x num_patch x d_model] decoder tokens
        cross: [batch x n_vars x mem_len x d_model] encoder memory
        """
        batch, n_vars, num_patch, d_model = x.shape
        # Fold variables into the batch dim: attention runs per variable.
        x = x.reshape(batch*n_vars, num_patch, d_model)

        # x = x.permute(0,2,1)
        # x = x + self.pos_embed(x)
        # x = x.permute(0,2,1)

        cross = cross.reshape(batch*n_vars, -1, d_model)

        # Causal self-attention over decoder positions.
        attention_mask = causal_attention_mask(num_patch).to(x.device)
        x_attn , _= self.self_attention(x, attn_mask=attention_mask)
        x_attn = self.norm1(x_attn) + x

        # Cross-attention to the encoder memory.
        x_cross , _ = self.cross_attention(x_attn, cross, cross)
        x_cross = self.dropout(self.norm2(x_cross)) + x_attn

        # Pointwise feed-forward.
        x_ff = self.MLP1(x_cross)
        x_ff = self.norm3(x_ff) + x_cross

        # Restore the [batch x n_vars x ...] layout.
        x_ff = x_ff.reshape(batch, n_vars, num_patch, d_model)

        return x_ff
531
+
532
def causal_attention_mask(seq_length):
    """Build an additive causal attention mask of shape (seq_length, seq_length).

    Entry (i, j) is 0.0 when j <= i (position j is visible when attending
    from position i) and -inf when j > i, so adding the mask to pre-softmax
    attention scores prevents attention to future positions.

    Args:
        seq_length (int): length of the sequence.

    Returns:
        torch.Tensor: additive mask of shape (seq_length, seq_length).
    """
    neg_inf = torch.full((seq_length, seq_length), float('-inf'))
    # Keep only the strict upper triangle at -inf; everything else is 0.
    return torch.triu(neg_inf, diagonal=1)
546
+
547
class CMlp(nn.Module):
    """Two-layer pointwise MLP implemented with kernel-size-1 Conv1d.

    Input/output layout is [bs x seq x features]; the tensor is transposed
    around the convolutions, which expect channels-first.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Conv1d(in_features, hidden, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv1d(hidden, out, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # [bs x seq x feat] -> [bs x feat x seq] for Conv1d.
        h = x.permute(0, 2, 1)
        h = self.drop(self.act(self.fc1(h)))
        h = self.drop(self.fc2(h))
        # Back to [bs x seq x feat].
        return h.permute(0, 2, 1)
566
+
567
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so a transpose can live
    inside ``nn.Sequential`` (optionally forcing a contiguous result)."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out
574
+
575
+
576
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False, rope_type=False):
        """Multi Head Attention Layer
        Input shape:
            Q:       [batch_size (bs) x max_q_len x d_model]
            K, V:    [batch_size (bs) x q_len x d_model]
            mask:    [q_len x q_len]

        rope_type selects the decoder RoPE variant in the underlying
        scaled-dot-product attention; res_attention additionally returns the
        raw pre-softmax scores for residual attention.
        """
        super().__init__()
        # Per-head key/value dims default to an even split of d_model.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout, res_attention=self.res_attention, lsa=lsa, rope_type=rope_type)

        # Project output
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))




    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        """Self-attention when K/V are omitted (they default to Q)."""
        bs = Q.size(0)
        if K is None: K = Q
        if V is None: V = Q

        # Linear (+ split in multiple heads)
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]

        # Apply Scaled Dot-Product Attention (multiple heads)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]

        # back to the original inputs dimensions
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights
629
+
630
class ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self attention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021). Rotary position embeddings (RoPE) are always applied
    to q/k; ``rope_type`` switches to the decoder variant."""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False, rope_type=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        # Scale is learnable only under locality self-attention (lsa).
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa
        self.rope_type = rope_type

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output:  [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''
        # using RoPE — k arrives transposed ([... x d_k x seq_len]), so it is
        # permuted to [... x seq_len x d_k] for rotation and back afterwards.
        if self.rope_type:
            q, k = RoPE_decoder(q, k.permute(0,1,3,2))
        else:
            q, k = RoPE(q, k.permute(0,1,3,2))
        k = k.permute(0,1,3,2)

        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None: attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights
691
+
692
def RoPE(q, k):
    """Apply rotary position embeddings to q and k in place of additive
    positional encodings. Relies on ``sinusoidal_position_embedding`` defined
    elsewhere in this file (not visible in this chunk)."""
    # q,k: (bs, head, max_len, output_dim)
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim)
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device, factor=1)

    # cos_pos, sin_pos: (bs, head, max_len, output_dim)
    # Per the RoPE formula, adjacent (cos, sin) entries share the same angle,
    # so each value is duplicated: e.g. (1,2,3) becomes (1,1,2,2,3,3).
    cos_pos = pos_emb[...,  1::2].repeat_interleave(2, dim=-1)  # odd columns (the cos terms), duplicated
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)  # even columns (the sin terms), duplicated

    # q,k: (bs, head, max_len, output_dim)
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
    q2 = q2.reshape(q.shape)  # after reshape the signs alternate: (-q1, q0, -q3, q2, ...)


    # rotate q: elementwise multiply with the cos/sin tables
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
    k2 = k2.reshape(k.shape)
    # rotate k: elementwise multiply with the cos/sin tables
    k = k * cos_pos + k2 * sin_pos

    return q, k
721
+
722
+
723
def RoPE_decoder(q, k):
    """Rotary position embedding for decoder cross-attention.

    q: [bs x n_heads x q_len x output_dim] (decoder queries)
    k: [bs x n_heads x k_len x output_dim] (encoder keys)
    A single position table of length k_len + q_len is built; keys take the
    first k_len positions and queries the last q_len, so queries are encoded
    as coming *after* the key context.
    Returns the rotated (q, k) pair with unchanged shapes.
    """
    bs, n_heads = q.shape[0], q.shape[1]
    q_len, k_len, dim = q.shape[2], k.shape[2], q.shape[-1]

    # Shared table spanning both segments: [bs x n_heads x (k_len + q_len) x dim].
    pos = sinusoidal_position_embedding(bs, n_heads, k_len + q_len, dim, q.device, factor=1)

    # Duplicate each sin/cos entry so it lines up with the (even, odd) feature pairs.
    cos_part = pos[..., 1::2].repeat_interleave(2, dim=-1)  # cos terms (odd slots)
    sin_part = pos[..., ::2].repeat_interleave(2, dim=-1)   # sin terms (even slots)

    def rotate(t):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...): the 2-D rotation partner.
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    # Queries use the trailing q_len positions of the table.
    q = q * cos_part[:, :, -q_len:, :] + rotate(q) * sin_part[:, :, -q_len:, :]
    # Keys use the leading k_len positions.
    k = k * cos_part[:, :, :k_len, :] + rotate(k) * sin_part[:, :, :k_len, :]
    return q, k
754
+
755
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device, factor=1.0):
    """Build a sinusoidal position-embedding table for RoPE.

    Returns a tensor of shape [batch_size x nums_head x max_len x output_dim]
    (for factor == 1) whose last axis interleaves sin/cos values:
    [sin(p*t0), cos(p*t0), sin(p*t1), cos(p*t1), ...] with
    t_i = 10000^(-2i/output_dim).
    For factor > 1 a denser table is built first and then subsampled back to
    max_len rows (interpolated positions for finer granularity).
    """
    # Fractional positions; with factor == 1 this is simply 0, 1, ..., max_len-1.
    positions = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)

    # Inverse frequencies for each of the output_dim // 2 rotation planes.
    plane_ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * plane_ids / output_dim)

    angles = positions * inv_freq  # [n_positions x output_dim//2]

    # Pair sin/cos per plane, then flatten so they interleave along the feature axis.
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
    # Broadcast the same table across batch and heads.
    table = table.repeat((batch_size, nums_head, 1, 1, 1))
    table = torch.reshape(table, (batch_size, nums_head, -1, output_dim)).to(device)

    # factor > 1: pick max_len evenly spaced rows out of the denser table.
    if factor > 1.0:
        keep = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, keep, :]

    return table
781
+
782
class PretrainHead(nn.Module):
    """Reconstruction head for masked-patch pretraining.

    Projects each d_model patch embedding back to patch_len raw values
    with a single linear layer (dropout applied before the projection).
    """

    def __init__(self, d_model, patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, patch_len)

    def forward(self, x):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        output: tensor [bs x num_patch x nvars x patch_len]
        """
        patches = x.transpose(2, 3)                    # [bs x nvars x num_patch x d_model]
        patches = self.linear(self.dropout(patches))   # [bs x nvars x num_patch x patch_len]
        return patches.permute(0, 2, 1, 3)             # [bs x num_patch x nvars x patch_len]
798
+
799
+
800
class decoder_PredictHead(nn.Module):
    """Decoder prediction head with an adaptive output patch length.

    Trained as a (d_model -> target_patch_len) projection; at inference the
    weight is resampled to `patch_len` via `resample_patchemb`, so the same
    head can emit patches of a different length than it was trained with.
    """

    def __init__(self, d_model, patch_len, target_patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, target_patch_len)
        self.patch_len = patch_len
        self.d_model = d_model

    def forward(self, x):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        output: tensor [bs x (num_patch * patch_len) x nvars]
        """
        # Resample the trained projection to the inference patch length and apply
        # it functionally. (The original built a fresh nn.Linear and mutated its
        # .weight.data on every forward call — wasteful, and it bypassed autograd.)
        # resample_patchemb returns [d_model x patch_len]; F.linear wants [out x in].
        weight = resample_patchemb(old=self.linear.weight.data.T,
                                   new_patch_len=self.patch_len).T  # [patch_len x d_model]

        x = x.transpose(2, 3)                   # [bs x nvars x num_patch x d_model]
        x = F.linear(self.dropout(x), weight)   # [bs x nvars x num_patch x patch_len] (no bias)
        x = x.permute(0, 2, 3, 1)               # [bs x num_patch x patch_len x nvars]
        # Flatten patches into one continuous horizon per variable.
        return x.reshape(x.shape[0], -1, x.shape[3])
820
+
821
def resample_patchemb(old: torch.Tensor, new_patch_len: int):
    """Resample a patch-embedding weight matrix to a new patch length.

    Builds a linear-interpolation resize operator as an explicit matrix and
    applies its pseudo-inverse to the kernels (PI-resize), so the resampled
    projection approximates the original one acting on resized patches.

    old: [d_model x patch_size] weight matrix.
    Returns a [d_model x new_patch_len] matrix (the input itself if the
    length already matches).
    """
    assert old.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old  # already the requested length — nothing to do

    kernels = old.T               # [patch_size x d_model]
    old_len = kernels.size(0)
    scale = new_patch_len / old_len

    def _resize(mat, length):
        # Linear interpolation along the last axis, batched via a dummy dim.
        return F.interpolate(mat.unsqueeze(0), size=length, mode='linear').squeeze(0)

    # Express the resize operation as a matrix by resizing an identity basis.
    basis = torch.eye(old_len, dtype=torch.float32, device=old.device)
    resize_mat = _resize(basis, new_patch_len).T       # [new_patch_len x old_len]
    # Pseudo-inverse of the resize operator.
    resize_pinv = torch.linalg.pinv(resize_mat.T)      # [new_patch_len x old_len]

    # Apply in one matrix product; sqrt(scale) rescales for the length change
    # (presumably to keep output magnitude comparable — inherited from the original).
    resampled = resize_pinv @ kernels * math.sqrt(scale)

    return resampled.T
856
+
857
+
858
def get_activation_fn(activation):
    """Resolve an activation spec to a module instance.

    activation: either a callable (e.g. a module class), which is invoked
    with no arguments, or one of the strings "relu" / "gelu"
    (case-insensitive). Raises ValueError for anything else.
    """
    if callable(activation):
        return activation()
    name = activation.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac0de5227afaac3014d45de31e650b6e277fac7c7628dd897bd49c4a6f4dad91
3
+ size 16018929
ts_generation_mixin.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Dict, List, Optional, Union, Callable
3
+ import torch
4
+ from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
5
+ from transformers.generation.utils import GenerationConfig, GenerateOutput
6
+ from transformers.utils import ModelOutput
7
+
8
+
9
class TSGenerationMixin(GenerationMixin):
    """Adapts the Hugging Face `generate` entry point to time-series forecasting.

    Instead of autoregressive token decoding, `generate` normalizes the input
    series (RevIN), patches it, runs a single forward pass, and de-normalizes
    the returned forecast. Most HF-generation parameters are accepted for
    interface compatibility but unused here.
    """

    @torch.no_grad()
    def generate(self,
                 inputs: Optional[torch.Tensor] = None,
                 generation_config: Optional[GenerationConfig] = None,
                 logits_processor: Optional[LogitsProcessorList] = None,
                 stopping_criteria: Optional[StoppingCriteriaList] = None,
                 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
                 synced_gpus: Optional[bool] = None,
                 assistant_model: Optional["PreTrainedModel"] = None,
                 streamer: Optional["BaseStreamer"] = None,
                 negative_prompt_ids: Optional[torch.Tensor] = None,
                 negative_prompt_attention_mask: Optional[torch.Tensor] = None,
                 revin: Optional[bool] = True,
                 patch_len: Optional[int] = 48,
                 stride_len: Optional[int] = 48,
                 max_output_length: Optional[int] = 96,
                 inference_patch_len: Optional[int] = 48,

                 **kwargs,
                 ) -> Union[GenerateOutput, torch.Tensor]:
        """Forecast from a raw series.

        inputs: [batch_size x seq_len x n_vars] tensor of observed values.
        revin: apply instance normalization before / de-normalize after.
        patch_len / stride_len: patching window and step.
        Returns the model's prediction, de-normalized when revin is True
        (shape determined by the underlying model, e.g. [bs x target_dim x n_vars]).
        """
        if inputs is None or len(inputs.shape) != 3:
            raise ValueError('Input shape must be: [batch_size, seq_len, n_vars]')

        if revin:
            # RevIN: normalize each series by its own mean/std; the same statistics
            # de-normalize the forecast below. Epsilon avoids division by zero.
            means = inputs.mean(dim=1, keepdim=True)
            stdev = inputs.std(dim=1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev

        # Patch the series: [bs x num_patch x n_vars x patch_len] with
        # num_patch = (seq_len - patch_len) // stride_len + 1.
        # `unfold` honors stride_len (the previous `view` silently assumed
        # stride_len == patch_len and crashed when seq_len was not a multiple
        # of patch_len); any trailing remainder shorter than a patch is dropped.
        patches = inputs.unfold(dimension=1, size=patch_len, step=stride_len)

        model_inputs = {
            "input": patches,
        }

        outputs = self(**model_inputs)  # [batch_size, target_dim, n_vars]
        outputs = outputs["prediction"]

        if revin:
            # Undo the instance normalization on the forecast.
            outputs = (outputs * stdev) + means

        return outputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # Single-pass forecasting: no per-step state (cache, attention mask)
        # to carry over, so the kwargs pass through unchanged.
        return model_kwargs
71
+
72
+