OpenTransformer commited on
Commit
de95b63
Β·
verified Β·
1 Parent(s): a18d8d3

Restore AGILLM3 final inference compatibility

Browse files
Files changed (1) hide show
  1. agillm35.py +57 -34
agillm35.py CHANGED
@@ -2513,16 +2513,23 @@ class NATHead(nn.Module):
2513
  return self.proj(h)
2514
 
2515
 
2516
- class SATHead(nn.Module):
2517
- def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None):
2518
- super().__init__()
2519
- self.tie_weights = tie_weights
2520
- if tie_weights and embedding_weight is not None:
2521
- self.proj = nn.Linear(d, VOCAB, bias=False)
2522
- self.proj.weight = embedding_weight
2523
- else:
2524
- self.proj = nn.Linear(d, VOCAB)
2525
- self.gate = nn.Linear(d, 2) if mode == "var" else None
 
 
 
 
 
 
 
2526
  def forward(self, h_last):
2527
  return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
2528
 
@@ -2978,7 +2985,7 @@ def infer_cfg_from_ckpt(path: pathlib.Path):
2978
 
2979
  # ───────────────────────── Training Logic ─────────────────────────
2980
 
2981
- def _load_infer_head_state(module: nn.Module, state: dict, name: str):
2982
  """Load inference heads across small checkpoint/schema drifts.
2983
 
2984
  Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT
@@ -3015,11 +3022,18 @@ def _load_infer_head_state(module: nn.Module, state: dict, name: str):
3015
  notes.append("zero-filled " + ", ".join(zero_filled[:6]))
3016
  if loaded.unexpected_keys:
3017
  notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6]))
3018
- if notes:
3019
- print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True)
3020
-
3021
-
3022
- def _parse_grow_plan(s: str) -> List[int]:
 
 
 
 
 
 
 
3023
  return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
3024
 
3025
  def _count_enabled_params(*modules) -> int:
@@ -3690,7 +3704,8 @@ def infer(args):
3690
  anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
3691
  ).to(DEV)
3692
  ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
3693
- sat_h = SATHead(cfg["d"]).to(DEV)
 
3694
  nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if ("nat" in sd or args.mode == "nat") else None
3695
  core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"]))
3696
  ar_h.load_state_dict(sd["ar"])
@@ -3732,10 +3747,12 @@ def infer(args):
3732
  h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1))
3733
  for _ in range(args.max_new):
3734
  logits = ar_h(h)[:, -1]
3735
- logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3736
- nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3737
- ids = torch.cat([ids, nxt], 1)
3738
- h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
 
 
3739
  elif args.mode == "nat":
3740
  # Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the
3741
  # BLANK slots, committing confident predictions each pass. Unlike the
@@ -3796,14 +3813,17 @@ def infer(args):
3796
  if ids.size(1) % SAT_BLOCK != 0:
3797
  logits = ar_h(h)[:, -1]
3798
  logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3799
- nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3800
- ids = torch.cat([ids, nxt], 1)
3801
- added += 1
3802
- h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
3803
- cached_len = ids.size(1)
3804
- h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
3805
-
3806
- while added < args.max_new:
 
 
 
3807
  logits_all, gate = sat_h(h_buffer)
3808
  stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
3809
  stride = min(int(stride), logits_all.size(1))
@@ -3812,11 +3832,14 @@ def infer(args):
3812
  logits = logits_all[:, i]
3813
  logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3814
  nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3815
- new_tokens.append(nxt)
3816
- ids = torch.cat([ids, nxt], 1)
3817
- added += 1
3818
- if added >= args.max_new: break
3819
- if added >= args.max_new: break
 
 
 
3820
  new_ids = torch.cat(new_tokens, dim=1)
3821
  mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args))
3822
  h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
 
2513
  return self.proj(h)
2514
 
2515
 
2516
+ class SATHead(nn.Module):
2517
+ def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None, mlp: bool = False):
2518
+ super().__init__()
2519
+ self.tie_weights = tie_weights
2520
+ self.mlp = bool(mlp)
2521
+ if self.mlp:
2522
+ self.proj = nn.Sequential(
2523
+ nn.Linear(d, d),
2524
+ nn.GELU(),
2525
+ nn.Linear(d, VOCAB),
2526
+ )
2527
+ elif tie_weights and embedding_weight is not None:
2528
+ self.proj = nn.Linear(d, VOCAB, bias=False)
2529
+ self.proj.weight = embedding_weight
2530
+ else:
2531
+ self.proj = nn.Linear(d, VOCAB)
2532
+ self.gate = nn.Linear(d, 2) if mode == "var" else None
2533
  def forward(self, h_last):
2534
  return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
2535
 
 
2985
 
2986
  # ───────────────────────── Training Logic ─────────────────────────
2987
 
2988
+ def _load_infer_head_state(module: nn.Module, state: dict, name: str):
2989
  """Load inference heads across small checkpoint/schema drifts.
2990
 
2991
  Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT
 
3022
  notes.append("zero-filled " + ", ".join(zero_filled[:6]))
3023
  if loaded.unexpected_keys:
3024
  notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6]))
3025
+ if notes:
3026
+ print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True)
3027
+
3028
+
3029
+ def _sat_head_mlp_from_state(sd: dict) -> bool:
3030
+ sat_sd = sd.get("sat", {})
3031
+ if sd.get("delta") and "weights" in sd:
3032
+ sat_sd = sd["weights"].get("sat", sat_sd)
3033
+ return any(str(key).startswith("proj.2.") for key in sat_sd)
3034
+
3035
+
3036
+ def _parse_grow_plan(s: str) -> List[int]:
3037
  return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
3038
 
3039
  def _count_enabled_params(*modules) -> int:
 
3704
  anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
3705
  ).to(DEV)
3706
  ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
3707
+ sat_head_mlp = bool(sd.get("sat_head_mlp", False) or _sat_head_mlp_from_state(sd))
3708
+ sat_h = SATHead(cfg["d"], mlp=sat_head_mlp).to(DEV)
3709
  nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if ("nat" in sd or args.mode == "nat") else None
3710
  core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"]))
3711
  ar_h.load_state_dict(sd["ar"])
 
3747
  h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1))
3748
  for _ in range(args.max_new):
3749
  logits = ar_h(h)[:, -1]
3750
+ logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3751
+ nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3752
+ ids = torch.cat([ids, nxt], 1)
3753
+ if EOS is not None and int(nxt.item()) == int(EOS):
3754
+ break
3755
+ h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
3756
  elif args.mode == "nat":
3757
  # Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the
3758
  # BLANK slots, committing confident predictions each pass. Unlike the
 
3813
  if ids.size(1) % SAT_BLOCK != 0:
3814
  logits = ar_h(h)[:, -1]
3815
  logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3816
+ nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3817
+ ids = torch.cat([ids, nxt], 1)
3818
+ added += 1
3819
+ if EOS is not None and int(nxt.item()) == int(EOS):
3820
+ stop = True
3821
+ if not stop:
3822
+ h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
3823
+ cached_len = ids.size(1)
3824
+ h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
3825
+
3826
+ while added < args.max_new and not stop:
3827
  logits_all, gate = sat_h(h_buffer)
3828
  stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
3829
  stride = min(int(stride), logits_all.size(1))
 
3832
  logits = logits_all[:, i]
3833
  logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
3834
  nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
3835
+ new_tokens.append(nxt)
3836
+ ids = torch.cat([ids, nxt], 1)
3837
+ added += 1
3838
+ if EOS is not None and int(nxt.item()) == int(EOS):
3839
+ stop = True
3840
+ break
3841
+ if added >= args.max_new: break
3842
+ if stop or added >= args.max_new: break
3843
  new_ids = torch.cat(new_tokens, dim=1)
3844
  mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args))
3845
  h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))