Restore AGILLM3 final inference compatibility
Browse files- agillm35.py +57 -34
agillm35.py
CHANGED
|
@@ -2513,16 +2513,23 @@ class NATHead(nn.Module):
|
|
| 2513 |
return self.proj(h)
|
| 2514 |
|
| 2515 |
|
| 2516 |
-
class SATHead(nn.Module):
|
| 2517 |
-
def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None):
|
| 2518 |
-
super().__init__()
|
| 2519 |
-
self.tie_weights = tie_weights
|
| 2520 |
-
|
| 2521 |
-
|
| 2522 |
-
self.proj
|
| 2523 |
-
|
| 2524 |
-
|
| 2525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2526 |
def forward(self, h_last):
|
| 2527 |
return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
|
| 2528 |
|
|
@@ -2978,7 +2985,7 @@ def infer_cfg_from_ckpt(path: pathlib.Path):
|
|
| 2978 |
|
| 2979 |
# βββββββββββββββββββββββββ Training Logic βββββββββββββββββββββββββ
|
| 2980 |
|
| 2981 |
-
def _load_infer_head_state(module: nn.Module, state: dict, name: str):
|
| 2982 |
"""Load inference heads across small checkpoint/schema drifts.
|
| 2983 |
|
| 2984 |
Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT
|
|
@@ -3015,11 +3022,18 @@ def _load_infer_head_state(module: nn.Module, state: dict, name: str):
|
|
| 3015 |
notes.append("zero-filled " + ", ".join(zero_filled[:6]))
|
| 3016 |
if loaded.unexpected_keys:
|
| 3017 |
notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6]))
|
| 3018 |
-
if notes:
|
| 3019 |
-
print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True)
|
| 3020 |
-
|
| 3021 |
-
|
| 3022 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3023 |
return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
|
| 3024 |
|
| 3025 |
def _count_enabled_params(*modules) -> int:
|
|
@@ -3690,7 +3704,8 @@ def infer(args):
|
|
| 3690 |
anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
|
| 3691 |
).to(DEV)
|
| 3692 |
ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
|
| 3693 |
-
|
|
|
|
| 3694 |
nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if ("nat" in sd or args.mode == "nat") else None
|
| 3695 |
core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"]))
|
| 3696 |
ar_h.load_state_dict(sd["ar"])
|
|
@@ -3732,10 +3747,12 @@ def infer(args):
|
|
| 3732 |
h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1))
|
| 3733 |
for _ in range(args.max_new):
|
| 3734 |
logits = ar_h(h)[:, -1]
|
| 3735 |
-
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3736 |
-
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3737 |
-
ids = torch.cat([ids, nxt], 1)
|
| 3738 |
-
|
|
|
|
|
|
|
| 3739 |
elif args.mode == "nat":
|
| 3740 |
# Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the
|
| 3741 |
# BLANK slots, committing confident predictions each pass. Unlike the
|
|
@@ -3796,14 +3813,17 @@ def infer(args):
|
|
| 3796 |
if ids.size(1) % SAT_BLOCK != 0:
|
| 3797 |
logits = ar_h(h)[:, -1]
|
| 3798 |
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3799 |
-
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3800 |
-
ids = torch.cat([ids, nxt], 1)
|
| 3801 |
-
added += 1
|
| 3802 |
-
|
| 3803 |
-
|
| 3804 |
-
|
| 3805 |
-
|
| 3806 |
-
|
|
|
|
|
|
|
|
|
|
| 3807 |
logits_all, gate = sat_h(h_buffer)
|
| 3808 |
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
|
| 3809 |
stride = min(int(stride), logits_all.size(1))
|
|
@@ -3812,11 +3832,14 @@ def infer(args):
|
|
| 3812 |
logits = logits_all[:, i]
|
| 3813 |
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3814 |
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3815 |
-
new_tokens.append(nxt)
|
| 3816 |
-
ids = torch.cat([ids, nxt], 1)
|
| 3817 |
-
added += 1
|
| 3818 |
-
if
|
| 3819 |
-
|
|
|
|
|
|
|
|
|
|
| 3820 |
new_ids = torch.cat(new_tokens, dim=1)
|
| 3821 |
mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args))
|
| 3822 |
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
|
|
|
| 2513 |
return self.proj(h)
|
| 2514 |
|
| 2515 |
|
| 2516 |
+
class SATHead(nn.Module):
|
| 2517 |
+
def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None, mlp: bool = False):
|
| 2518 |
+
super().__init__()
|
| 2519 |
+
self.tie_weights = tie_weights
|
| 2520 |
+
self.mlp = bool(mlp)
|
| 2521 |
+
if self.mlp:
|
| 2522 |
+
self.proj = nn.Sequential(
|
| 2523 |
+
nn.Linear(d, d),
|
| 2524 |
+
nn.GELU(),
|
| 2525 |
+
nn.Linear(d, VOCAB),
|
| 2526 |
+
)
|
| 2527 |
+
elif tie_weights and embedding_weight is not None:
|
| 2528 |
+
self.proj = nn.Linear(d, VOCAB, bias=False)
|
| 2529 |
+
self.proj.weight = embedding_weight
|
| 2530 |
+
else:
|
| 2531 |
+
self.proj = nn.Linear(d, VOCAB)
|
| 2532 |
+
self.gate = nn.Linear(d, 2) if mode == "var" else None
|
| 2533 |
def forward(self, h_last):
|
| 2534 |
return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None)
|
| 2535 |
|
|
|
|
| 2985 |
|
| 2986 |
# βββββββββββββββββββββββββ Training Logic βββββββββββββββββββββββββ
|
| 2987 |
|
| 2988 |
+
def _load_infer_head_state(module: nn.Module, state: dict, name: str):
|
| 2989 |
"""Load inference heads across small checkpoint/schema drifts.
|
| 2990 |
|
| 2991 |
Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT
|
|
|
|
| 3022 |
notes.append("zero-filled " + ", ".join(zero_filled[:6]))
|
| 3023 |
if loaded.unexpected_keys:
|
| 3024 |
notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6]))
|
| 3025 |
+
if notes:
|
| 3026 |
+
print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True)
|
| 3027 |
+
|
| 3028 |
+
|
| 3029 |
+
def _sat_head_mlp_from_state(sd: dict) -> bool:
|
| 3030 |
+
sat_sd = sd.get("sat", {})
|
| 3031 |
+
if sd.get("delta") and "weights" in sd:
|
| 3032 |
+
sat_sd = sd["weights"].get("sat", sat_sd)
|
| 3033 |
+
return any(str(key).startswith("proj.2.") for key in sat_sd)
|
| 3034 |
+
|
| 3035 |
+
|
| 3036 |
+
def _parse_grow_plan(s: str) -> List[int]:
|
| 3037 |
return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
|
| 3038 |
|
| 3039 |
def _count_enabled_params(*modules) -> int:
|
|
|
|
| 3704 |
anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
|
| 3705 |
).to(DEV)
|
| 3706 |
ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
|
| 3707 |
+
sat_head_mlp = bool(sd.get("sat_head_mlp", False) or _sat_head_mlp_from_state(sd))
|
| 3708 |
+
sat_h = SATHead(cfg["d"], mlp=sat_head_mlp).to(DEV)
|
| 3709 |
nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if ("nat" in sd or args.mode == "nat") else None
|
| 3710 |
core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"]))
|
| 3711 |
ar_h.load_state_dict(sd["ar"])
|
|
|
|
| 3747 |
h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1))
|
| 3748 |
for _ in range(args.max_new):
|
| 3749 |
logits = ar_h(h)[:, -1]
|
| 3750 |
+
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3751 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3752 |
+
ids = torch.cat([ids, nxt], 1)
|
| 3753 |
+
if EOS is not None and int(nxt.item()) == int(EOS):
|
| 3754 |
+
break
|
| 3755 |
+
h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 3756 |
elif args.mode == "nat":
|
| 3757 |
# Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the
|
| 3758 |
# BLANK slots, committing confident predictions each pass. Unlike the
|
|
|
|
| 3813 |
if ids.size(1) % SAT_BLOCK != 0:
|
| 3814 |
logits = ar_h(h)[:, -1]
|
| 3815 |
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3816 |
+
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3817 |
+
ids = torch.cat([ids, nxt], 1)
|
| 3818 |
+
added += 1
|
| 3819 |
+
if EOS is not None and int(nxt.item()) == int(EOS):
|
| 3820 |
+
stop = True
|
| 3821 |
+
if not stop:
|
| 3822 |
+
h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|
| 3823 |
+
cached_len = ids.size(1)
|
| 3824 |
+
h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
|
| 3825 |
+
|
| 3826 |
+
while added < args.max_new and not stop:
|
| 3827 |
logits_all, gate = sat_h(h_buffer)
|
| 3828 |
stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
|
| 3829 |
stride = min(int(stride), logits_all.size(1))
|
|
|
|
| 3832 |
logits = logits_all[:, i]
|
| 3833 |
logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
|
| 3834 |
nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
|
| 3835 |
+
new_tokens.append(nxt)
|
| 3836 |
+
ids = torch.cat([ids, nxt], 1)
|
| 3837 |
+
added += 1
|
| 3838 |
+
if EOS is not None and int(nxt.item()) == int(EOS):
|
| 3839 |
+
stop = True
|
| 3840 |
+
break
|
| 3841 |
+
if added >= args.max_new: break
|
| 3842 |
+
if stop or added >= args.max_new: break
|
| 3843 |
new_ids = torch.cat(new_tokens, dim=1)
|
| 3844 |
mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args))
|
| 3845 |
h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
|