razmars commited on
Commit
eeb454f
·
verified ·
1 Parent(s): 21831ca

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +7 -1
modeling_super_linear.py CHANGED
@@ -216,11 +216,17 @@ class RLinear(nn.Module):
216
  x_shape = x.shape
217
  if x.shape[1] < self.seq_len:
218
  if self.zero_shot_Linear is None:
 
219
  self.transform_model(x.shape[1])
 
 
 
220
  x = x.clone()
221
  x = self.revin_layer(x, 'norm')
222
- x = F.linear(x, self.zero_shot_Linear).clone()
223
  x = self.revin_layer(x, 'denorm')
 
 
224
  return x
225
 
226
 
 
216
  x_shape = x.shape
217
  if x.shape[1] < self.seq_len:
218
  if self.zero_shot_Linear is None:
219
+ print(F"new Lookkback : {x.shape[1]}")
220
  self.transform_model(x.shape[1])
221
+
222
+ if len(x_shape) == 2:
223
+ x = x.unsqueeze(-1)
224
  x = x.clone()
225
  x = self.revin_layer(x, 'norm')
226
+ x = F.linear(x.permute(0,2,1), self.zero_shot_Linear).permute(0,2,1).clone()
227
  x = self.revin_layer(x, 'denorm')
228
+ if len(x_shape) == 2:
229
+ x = x.squeeze(-1)
230
  return x
231
 
232