Update modeling_super_linear.py
Browse files- modeling_super_linear.py +9 -2
modeling_super_linear.py
CHANGED
|
@@ -212,7 +212,7 @@ class RLinear(nn.Module):
|
|
| 212 |
new_W = new_W * final_scaling
|
| 213 |
|
| 214 |
self.zero_shot_Linear = new_W
|
| 215 |
-
|
| 216 |
W = self.Linear.weight.detach()
|
| 217 |
W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
|
| 218 |
|
|
@@ -225,6 +225,13 @@ class RLinear(nn.Module):
|
|
| 225 |
)[0, 0] # drop the two singleton dims
|
| 226 |
|
| 227 |
self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
|
| 229 |
|
| 230 |
def forward(self, x):
|
|
@@ -234,7 +241,7 @@ class RLinear(nn.Module):
|
|
| 234 |
#if self.zero_shot_Linear is None:
|
| 235 |
#print(F"new Lookkback : {x.shape[1]}")
|
| 236 |
|
| 237 |
-
self.transform_model(x.shape[1],
|
| 238 |
|
| 239 |
x = x.clone()
|
| 240 |
#x = x * (x.shape[1]/512)
|
|
|
|
| 212 |
new_W = new_W * final_scaling
|
| 213 |
|
| 214 |
self.zero_shot_Linear = new_W
|
| 215 |
+
elif mode ==2:
|
| 216 |
W = self.Linear.weight.detach()
|
| 217 |
W4d = W.unsqueeze(0).unsqueeze(0) # (1, 1, out, in)
|
| 218 |
|
|
|
|
| 225 |
)[0, 0] # drop the two singleton dims
|
| 226 |
|
| 227 |
self.zero_shot_Linear = new_W # shape (self.horizon, new_lookback)
|
| 228 |
+
else:
|
| 229 |
+
W = self.Linear.weight.detach()
|
| 230 |
+
W_3d = W.t().unsqueeze(0)
|
| 231 |
+
W_pool = F.adaptive_avg_pool1d(W_3d, output_size=new_lookback)
|
| 232 |
+
W_new = W_pool.squeeze(0).permute(2, 0).contiguous()
|
| 233 |
+
self.zero_shot_Linear =
|
| 234 |
+
|
| 235 |
|
| 236 |
|
| 237 |
def forward(self, x):
|
|
|
|
| 241 |
#if self.zero_shot_Linear is None:
|
| 242 |
#print(F"new Lookkback : {x.shape[1]}")
|
| 243 |
|
| 244 |
+
self.transform_model(x.shape[1],3)
|
| 245 |
|
| 246 |
x = x.clone()
|
| 247 |
#x = x * (x.shape[1]/512)
|