razmars commited on
Commit
b5db567
·
verified ·
1 Parent(s): 7486e1b

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +12 -15
modeling_super_linear.py CHANGED
@@ -213,26 +213,23 @@ class RLinear(nn.Module):
213
  def forward(self, x):
214
  # x: [Batch, Input length,Channel]
215
  x_shape = x.shape
216
- print(x.shape)
217
- if len(x_shape) == 2:
218
- x = x.unsqueeze(-1)
219
-
220
- B,L,V = x.shape
221
- if L < self.seq_len and self.zero_shot_Linear is None:
222
- print(F"New Lookback :{L}")
223
- self.transform_model(L)
224
-
225
- if L < self.seq_len:
226
  x = x.clone()
227
  x = self.revin_layer(x, 'norm')
228
  x = F.linear(x, self.zero_shot_Linear)
229
  x = self.revin_layer(x, 'denorm')
 
230
 
231
- else:
232
- x = x.clone()
233
- x = self.revin_layer(x, 'norm')
234
- x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
235
- x = self.revin_layer(x, 'denorm')
 
 
 
236
 
237
  if len(x_shape) == 2:
238
  x = x.squeeze(-1)
 
213
  def forward(self, x):
214
  # x: [Batch, Input length,Channel]
215
  x_shape = x.shape
216
+ if x.shape[1] < self.seq_len:
217
+ if self.zero_shot_Linear is None:
218
+ self.transform_model(L)
 
 
 
 
 
 
 
219
  x = x.clone()
220
  x = self.revin_layer(x, 'norm')
221
  x = F.linear(x, self.zero_shot_Linear)
222
  x = self.revin_layer(x, 'denorm')
223
+ return x
224
 
225
+
226
+ if len(x_shape) == 2:
227
+ x = x.unsqueeze(-1)
228
+
229
+ x = x.clone()
230
+ x = self.revin_layer(x, 'norm')
231
+ x = self.Linear(x.permute(0,2,1)).permute(0,2,1).clone()
232
+ x = self.revin_layer(x, 'denorm')
233
 
234
  if len(x_shape) == 2:
235
  x = x.squeeze(-1)