razmars commited on
Commit
0d2aee9
·
verified ·
1 Parent(s): 376da5d

Update modeling_super_linear.py

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +2 -13
modeling_super_linear.py CHANGED
@@ -464,11 +464,9 @@ class superLinear(nn.Module):
464
  else:
465
  out, self.moe_loss = self.moe(x)
466
 
467
- print(F"auto_regressive: {self.auto_regressive}")
468
- print(F"max_horizon: {self.max_horizon}")
469
- print(F"self.inf_pred_len: {self.inf_pred_len}")
470
  if self.auto_regressive and self.max_horizon < self.inf_pred_len:
471
- print("bitch")
472
  outputs = [out]
473
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
474
  for i in range(0, self.inf_pred_len, self.max_horizon):
@@ -497,18 +495,9 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
497
 
498
 
499
  # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
500
- #print("SuperLinearForCausalLM")
501
- #print(config)
502
  backbone_cfg = type("Cfg", (), config.to_dict())()
503
  self.args = backbone_cfg
504
  self.backbone = superLinear(backbone_cfg)
505
-
506
- # optional final projection: map backbone output to discrete bins
507
- # (delete if your model already returns logits over a vocabulary)
508
- '''self.vocab_size = getattr(config, "vocab_size", None)
509
- if self.vocab_size is not None:
510
- self.lm_head = nn.Linear(backbone_cfg.pred_len, self.vocab_size)'''
511
-
512
  self.post_init()
513
 
514
  # ------------------------------------------------------------------
 
464
  else:
465
  out, self.moe_loss = self.moe(x)
466
 
467
+
 
 
468
  if self.auto_regressive and self.max_horizon < self.inf_pred_len:
469
+ #print("bitch")
470
  outputs = [out]
471
  ar_x = torch.cat([x, out], dim=1)[:, -self.seq_len:]
472
  for i in range(0, self.inf_pred_len, self.max_horizon):
 
495
 
496
 
497
  # the backbone keeps its own Config dataclass, so build one on‑the‑fly:
 
 
498
  backbone_cfg = type("Cfg", (), config.to_dict())()
499
  self.args = backbone_cfg
500
  self.backbone = superLinear(backbone_cfg)
 
 
 
 
 
 
 
501
  self.post_init()
502
 
503
  # ------------------------------------------------------------------