razmars commited on
Commit
2256b8f
·
verified ·
1 Parent(s): db0c818

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -3
modeling_super_linear.py CHANGED
@@ -204,8 +204,8 @@ class RLinear(nn.Module):
204
  def transform_model(self,new_lookback,mode):
205
  if mode == 1:
206
  W = self.Linear.weight.detach()
207
- #new_W = W[:, -new_lookback:]
208
- new_W = W[:, :new_lookback]
209
  original_norm = torch.norm(W, p=2)
210
  new_norm = torch.norm(new_W, p=2)
211
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
@@ -227,7 +227,6 @@ class RLinear(nn.Module):
227
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
228
 
229
 
230
-
231
  def forward(self, x):
232
  # x: [Batch, Input length,Channel]
233
  x_shape = x.shape
 
204
  def transform_model(self,new_lookback,mode):
205
  if mode == 1:
206
  W = self.Linear.weight.detach()
207
+ new_W = W[:, -new_lookback:]
208
+ #new_W = W[:, :new_lookback]
209
  original_norm = torch.norm(W, p=2)
210
  new_norm = torch.norm(new_W, p=2)
211
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
 
227
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
228
 
229
 
 
230
  def forward(self, x):
231
  # x: [Batch, Input length,Channel]
232
  x_shape = x.shape