Update modeling_super_linear.py
Browse files- modeling_super_linear.py +2 -13
modeling_super_linear.py
CHANGED
|
@@ -464,11 +464,9 @@ class superLinear(nn.Module):
|
|
| 464 |
else:
|
| 465 |
out, self.moe_loss = self.moe(x)
|
| 466 |
|
| 467 |
-
|
| 468 |
-
print(F"max_horizon: {self.max_horizon}")
|
| 469 |
-
print(F"self.inf_pred_len: {self.inf_pred_len}")
|
| 470 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
| 471 |
-
print("bitch")
|
| 472 |
outputs = [out]
|
| 473 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
| 474 |
for i in range(0, self.inf_pred_len, self.max_horizon):
|
|
@@ -497,18 +495,9 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
|
| 497 |
|
| 498 |
|
| 499 |
# the backbone keeps its own Config dataclass, so build one on-the-fly:
|
| 500 |
-
#print("SuperLinearForCausalLM")
|
| 501 |
-
#print(config)
|
| 502 |
backbone_cfg = type("Cfg", (), config.to_dict())()
|
| 503 |
self.args = backbone_cfg
|
| 504 |
self.backbone = superLinear(backbone_cfg)
|
| 505 |
-
|
| 506 |
-
# optional final projection: map backbone output to discrete bins
|
| 507 |
-
# (delete if your model already returns logits over a vocabulary)
|
| 508 |
-
'''self.vocab_size = getattr(config, "vocab_size", None)
|
| 509 |
-
if self.vocab_size is not None:
|
| 510 |
-
self.lm_head = nn.Linear(backbone_cfg.pred_len, self.vocab_size)'''
|
| 511 |
-
|
| 512 |
self.post_init()
|
| 513 |
|
| 514 |
# ------------------------------------------------------------------
|
|
|
|
| 464 |
else:
|
| 465 |
out, self.moe_loss = self.moe(x)
|
| 466 |
|
| 467 |
+
|
|
|
|
|
|
|
| 468 |
if self.auto_regressive and self.max_horizon < self.inf_pred_len:
|
| 469 |
+
#print("bitch")
|
| 470 |
outputs = [out]
|
| 471 |
ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
|
| 472 |
for i in range(0, self.inf_pred_len, self.max_horizon):
|
|
|
|
| 495 |
|
| 496 |
|
| 497 |
# the backbone keeps its own Config dataclass, so build one on-the-fly:
|
|
|
|
|
|
|
| 498 |
backbone_cfg = type("Cfg", (), config.to_dict())()
|
| 499 |
self.args = backbone_cfg
|
| 500 |
self.backbone = superLinear(backbone_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
self.post_init()
|
| 502 |
|
| 503 |
# ------------------------------------------------------------------
|