razmars committed on
Commit
6bcd495
·
verified ·
1 Parent(s): 2bf5e6c

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +11 -11
modeling_super_linear.py CHANGED
@@ -199,14 +199,15 @@ class RLinear(nn.Module):
199
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
200
  self.zero_shot_Linear = None
201
 
202
- def transform_model(self,new_lookback):
203
- W = self.Linear.weight.detach()
204
- new_W = W[:, -new_lookback:]
205
- original_norm = torch.norm(W, p=2)
206
- new_norm = torch.norm(new_W, p=2)
207
- final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
208
- #final_scaling = 1
209
- new_W = new_W * final_scaling
 
210
 
211
  self.zero_shot_Linear = new_W
212
 
@@ -219,13 +220,12 @@ class RLinear(nn.Module):
219
  #print(F"new Lookkback : {x.shape[1]}")
220
  self.transform_model(x.shape[1])
221
 
222
-
223
  x = x.clone()
224
- x = x * (x.shape[1]/512)
225
  x = self.revin_layer(x, 'norm')
226
  x = F.linear(x, self.zero_shot_Linear)
227
  x = self.revin_layer(x, 'denorm')
228
- x = x * (512/x.shape[1])
229
  return x
230
 
231
 
 
199
  self.revin_layer = RevIN(num_features = None, affine=False, norm_type = None, subtract_last = False)
200
  self.zero_shot_Linear = None
201
 
202
+ def transform_model(self,new_lookback,mode):
203
+ if mode == 2:
204
+ W = self.Linear.weight.detach()
205
+ new_W = W[:, -new_lookback:]
206
+ original_norm = torch.norm(W, p=2)
207
+ new_norm = torch.norm(new_W, p=2)
208
+ final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
209
+ #final_scaling = 1
210
+ new_W = new_W * final_scaling
211
 
212
  self.zero_shot_Linear = new_W
213
 
 
220
  #print(F"new Lookkback : {x.shape[1]}")
221
  self.transform_model(x.shape[1])
222
 
 
223
  x = x.clone()
224
+ #x = x * (x.shape[1]/512)
225
  x = self.revin_layer(x, 'norm')
226
  x = F.linear(x, self.zero_shot_Linear)
227
  x = self.revin_layer(x, 'denorm')
228
+ #x = x * (512/x.shape[1])
229
  return x
230
 
231