razmars commited on
Commit
bc79359
·
verified ·
1 Parent(s): dabbd29

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -3
modeling_super_linear.py CHANGED
@@ -205,6 +205,7 @@ class RLinear(nn.Module):
205
  original_norm = torch.norm(W, p=2)
206
  new_norm = torch.norm(new_W, p=2)
207
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
 
208
  new_W = new_W * final_scaling
209
 
210
  self.zero_shot_Linear = new_W
@@ -216,10 +217,10 @@ class RLinear(nn.Module):
216
  if x.shape[1] < self.seq_len:
217
  if self.zero_shot_Linear is None:
218
  self.transform_model(x.shape[1])
219
- #x = x.clone()
220
- #x = self.revin_layer(x, 'norm')
221
  x = F.linear(x, self.zero_shot_Linear)
222
- #x = self.revin_layer(x, 'denorm')
223
  return x
224
 
225
 
 
205
  original_norm = torch.norm(W, p=2)
206
  new_norm = torch.norm(new_W, p=2)
207
  final_scaling = original_norm / new_norm if new_norm.item() != 0 else 1.0
208
+ final_scaling = 1
209
  new_W = new_W * final_scaling
210
 
211
  self.zero_shot_Linear = new_W
 
217
  if x.shape[1] < self.seq_len:
218
  if self.zero_shot_Linear is None:
219
  self.transform_model(x.shape[1])
220
+ x = x.clone()
221
+ x = self.revin_layer(x, 'norm')
222
  x = F.linear(x, self.zero_shot_Linear)
223
+ x = self.revin_layer(x, 'denorm')
224
  return x
225
 
226