razmars commited on
Commit
9d7bcdf
·
verified ·
1 Parent(s): 554ce9f

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +9 -4
modeling_super_linear.py CHANGED
@@ -239,13 +239,18 @@ class RLinear(nn.Module):
239
  #print(F"new Lookkback : {x.shape[1]}")
240
 
241
  self.transform_model(x.shape[1],1)
 
 
 
 
 
242
 
243
  #x = x * (x.shape[1]/512)
244
- x = self.revin_layer(x, 'norm')
245
- x = F.linear(x, self.zero_shot_Linear)
246
- x = self.revin_layer(x, 'denorm')
247
  #x = x * (512/x.shape[1])
248
- return x
249
 
250
 
251
  if len(x_shape) == 2:
 
239
  #print(F"new Lookkback : {x.shape[1]}")
240
 
241
  self.transform_model(x.shape[1],1)
242
+ new_x = x.unsqueeze(-1)
243
+ seq_last = new_x[:,-1:,:].detach()
244
+ new_x = new_x - seq_last
245
+ new_x = self.Linear(new_x.permute(0,2,1)).permute(0,2,1)
246
+ return new_x + seq_last
247
 
248
  #x = x * (x.shape[1]/512)
249
+ #x = self.revin_layer(x, 'norm')
250
+ #x = F.linear(x, self.zero_shot_Linear)
251
+ #x = self.revin_layer(x, 'denorm')
252
  #x = x * (512/x.shape[1])
253
+ #return x
254
 
255
 
256
  if len(x_shape) == 2: