razmars commited on
Commit
1a6182c
·
verified ·
1 Parent(s): 9904639

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +9 -2
modeling_super_linear.py CHANGED
@@ -212,7 +212,7 @@ class RLinear(nn.Module):
212
  new_W = new_W * final_scaling
213
 
214
  self.zero_shot_Linear = new_W
215
- else:
216
  W = self.Linear.weight.detach()
217
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
218
 
@@ -225,6 +225,13 @@ class RLinear(nn.Module):
225
  )[0, 0] # drop the two singleton dims
226
 
227
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
 
 
 
 
 
 
 
228
 
229
 
230
  def forward(self, x):
@@ -234,7 +241,7 @@ class RLinear(nn.Module):
234
  #if self.zero_shot_Linear is None:
235
  #print(F"new Lookkback : {x.shape[1]}")
236
 
237
- self.transform_model(x.shape[1],1)
238
 
239
  x = x.clone()
240
  #x = x * (x.shape[1]/512)
 
212
  new_W = new_W * final_scaling
213
 
214
  self.zero_shot_Linear = new_W
215
+ elif mode ==2:
216
  W = self.Linear.weight.detach()
217
  W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
218
 
 
225
  )[0, 0] # drop the two singleton dims
226
 
227
  self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
228
+ else:
229
+ W = self.Linear.weight.detach()
230
+ W_3d = W.t().unsqueeze(0)
231
+ W_pool = F.adaptive_avg_pool1d(W_3d, output_size=new_lookback)
232
+ W_new = W_pool.squeeze(0).permute(2, 0).contiguous()
233
+ self.zero_shot_Linear =
234
+
235
 
236
 
237
  def forward(self, x):
 
241
  #if self.zero_shot_Linear is None:
242
  #print(F"new Lookkback : {x.shape[1]}")
243
 
244
+ self.transform_model(x.shape[1],3)
245
 
246
  x = x.clone()
247
  #x = x * (x.shape[1]/512)