OpenTransformer's picture
download
raw
195 kB
#!/usr/bin/env python3
"""AGILLM4.1 mainline single-file trainer/inference runtime.
AGILLM4.1 is the promoted AGILLM4 mainline evolved from the AGILLM3.5
prototype, and it is larger than AGILLM3/AGILLM3.5. Resumed checkpoints are
the source of truth for the exact architecture, with AGILLM4 presets available
for fresh starts. This file is mechanically folded from AGILLM4 plus
compatibility patches:
- DeepSeek-V4-Pro tokenizer/checkpoint support by default
- DeepSeek-V3.2 legacy compatibility support through the agillm35 shim
- AR + SAT checkpoint schema compatibility; NAT can be disabled with --agillm3_compat
- DiffusionBlock training support and optional async side-update ingestion
"""
from __future__ import annotations
# Single-file module alias: helper code still imports the historical module names.
import sys as _agillm41_sys
# Register this very module object under each historical module name so that
# `import nB300_agillm4` / `import agillm35` / `import agillm41` executed later
# (e.g. the lazy import inside `_dblock_step`) resolves to this file.
# `setdefault` leaves any module already registered under that name untouched.
_agillm41_sys.modules.setdefault("nB300_agillm4", _agillm41_sys.modules[__name__])
_agillm41_sys.modules.setdefault("agillm35", _agillm41_sys.modules[__name__])
_agillm41_sys.modules.setdefault("agillm41", _agillm41_sys.modules[__name__])
# ===== BEGIN anchor_memory.py =====
#!/usr/bin/env python3
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class AnchorMemoryConfig:
    """Hyperparameters consumed by ``AnchorMemoryLayer``."""

    # Feature width of the token stream; also the width of every anchor vector.
    d_model: int
    # Number of heads handed to the anchor-read ``nn.MultiheadAttention``.
    heads: int
    # Number of consecutive tokens pooled into one anchor by the compressor.
    anchor_stride: int = 256
    # Upper bound on the anchor bank length; only the newest anchors are kept.
    max_anchors: int = 2048
    # Dropout probability passed to the anchor-read attention.
    dropout: float = 0.0
class AnchorCompressor(nn.Module):
    """Compress local token spans into trainable anchor vectors.

    The sequence is cut into consecutive chunks of ``anchor_stride`` tokens.
    Each chunk is reduced to one vector by a learned softmax-weighted average
    and then refined by a residual MLP.
    """

    def __init__(self, d_model: int, anchor_stride: int):
        super().__init__()
        self.anchor_stride = anchor_stride
        self.score = nn.Linear(d_model, 1)
        self.mix = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` of shape (batch, seq, dim) into (batch, ceil(seq / stride), dim).

        When ``seq`` is not a multiple of the stride the tail is zero-padded.
        The padded slots are excluded from the pooling softmax so they cannot
        steal weight from the real tokens of the last chunk.
        """
        bsz, seq, dim = x.shape
        stride = self.anchor_stride
        pad = (-seq) % stride
        if pad:
            x = F.pad(x, (0, 0, 0, pad))
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        chunks = x.reshape(bsz, -1, stride, dim)
        scores = self.score(chunks)
        if pad:
            # pad < stride, so every chunk keeps at least one real token and the
            # masked softmax below can never see an all -inf row.
            is_real = torch.arange(seq + pad, device=x.device) < seq
            scores = scores.masked_fill(~is_real.view(1, -1, stride, 1), float("-inf"))
        weights = scores.softmax(dim=2)
        pooled = (chunks * weights).sum(dim=2)
        return pooled + self.mix(pooled)
class AnchorMemoryLayer(nn.Module):
    """Local-token stream reads from a bounded bank of learned anchors."""

    def __init__(self, cfg: AnchorMemoryConfig):
        super().__init__()
        self.cfg = cfg
        self.compress = AnchorCompressor(cfg.d_model, cfg.anchor_stride)
        self.q_ln = nn.LayerNorm(cfg.d_model)
        self.mem_ln = nn.LayerNorm(cfg.d_model)
        self.read = nn.MultiheadAttention(
            cfg.d_model,
            cfg.heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.gate = nn.Sequential(nn.Linear(2 * cfg.d_model, cfg.d_model), nn.Sigmoid())
        self.out_ln = nn.LayerNorm(cfg.d_model)

    def forward(
        self,
        x: torch.Tensor,
        memory: torch.Tensor | None = None,
        *,
        detach_memory: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append anchors for ``x`` to the bank and let ``x`` attend over it.

        Args:
            x: token features; the last dimension must be ``cfg.d_model``.
            memory: anchor bank returned by a previous call, or ``None``.
            detach_memory: detach the anchors produced from ``x`` before they
                enter the bank (the incoming ``memory`` is used as given).

        Returns:
            ``(y, bank)`` where ``y`` has the shape of ``x`` and ``bank`` is
            the updated anchor bank, truncated to the newest
            ``cfg.max_anchors`` entries.
        """
        new_anchors = self.compress(x)
        if detach_memory:
            new_anchors = new_anchors.detach()
        bank = new_anchors if memory is None else torch.cat([memory, new_anchors], dim=1)
        if bank.size(1) > self.cfg.max_anchors:
            bank = bank[:, -self.cfg.max_anchors :]
        # Normalise the bank once and reuse it as key and value; running the
        # LayerNorm twice only duplicated compute and saved activations.
        bank_normed = self.mem_ln(bank)
        recalled, _ = self.read(self.q_ln(x), bank_normed, bank_normed, need_weights=False)
        gate = self.gate(torch.cat([x, recalled], dim=-1))
        mixed = x + gate * recalled
        return self.out_ln(mixed), bank
def smoke_test() -> None:
    """Run the anchor layer twice on random data and check the output shapes."""
    config = AnchorMemoryConfig(d_model=128, heads=8, anchor_stride=32, max_anchors=64)
    anchor_layer = AnchorMemoryLayer(config)
    tokens = torch.randn(2, 256, 128)

    # First pass starts with an empty bank: 256 tokens / stride 32 -> 8 anchors.
    first_out, first_bank = anchor_layer(tokens)
    assert first_out.shape == tokens.shape
    assert first_bank.shape == (2, 8, 128)

    # Second pass appends another 8 anchors to the returned bank.
    second_out, second_bank = anchor_layer(tokens, first_bank)
    assert second_out.shape == tokens.shape
    assert second_bank.shape == (2, 16, 128)

    print("anchor_memory smoke OK", first_out.shape, second_bank.shape)
# ===== END anchor_memory.py =====
# ===== BEGIN fused_ce.py =====
"""Fused cross-entropy: streams over the VOCAB dimension (online-softmax) so the
[N x V] logit matrix is NEVER materialized -- only [N x vchunk]. Custom backward
recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the
DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to
the output head instead of network depth."""
import torch
def _fused_ce_no_autocast(device):
    """Return a context manager that switches autocast off for ``device``'s backend."""
    from contextlib import nullcontext

    try:
        return torch.autocast(device_type=device.type, enabled=False)
    except (RuntimeError, ValueError):
        # Backend without autocast support: there is nothing to disable.
        return nullcontext()


class FusedCE(torch.autograd.Function):
    """Mean cross-entropy of ``h @ W.T`` against ``tgt``, streamed over the vocabulary.

    Only an ``[N, vchunk]`` slice of the logits exists at any time. The forward
    pass keeps a running max / sum (online softmax), and the backward pass
    recomputes the softmax chunk by chunk (``grad = softmax - onehot``).
    All arithmetic runs in float32 with autocast disabled.
    """

    @staticmethod
    def forward(ctx, h, W, tgt, vchunk=16384):
        """Compute the loss.

        Args:
            h: ``[N, d]`` hidden states.
            W: ``[V, d]`` output projection weight.
            tgt: ``[N]`` integer class indices in ``[0, V)``.
            vchunk: number of vocabulary rows processed per step (must be > 0).

        Returns:
            0-dim float32 tensor holding the mean loss over the ``N`` rows.
        """
        vchunk = int(vchunk)
        if vchunk <= 0:
            raise ValueError(f"vchunk must be positive, got {vchunk}")
        n_rows = h.shape[0]
        vocab = W.shape[0]
        # Disable autocast for the backend the tensors live on (the CUDA-only
        # context left CPU autocast active and is deprecated).
        with _fused_ce_no_autocast(h.device):
            hf = h.float()
            run_max = torch.full((n_rows,), -1e30, device=h.device, dtype=torch.float32)
            run_sum = torch.zeros(n_rows, device=h.device, dtype=torch.float32)
            tgt_logit = torch.zeros(n_rows, device=h.device, dtype=torch.float32)
            for c in range(0, vocab, vchunk):
                # Cast one weight chunk at a time so no full fp32 copy of W is held.
                lg = hf @ W[c:c + vchunk].float().T  # [N, vchunk] transient only
                new_max = torch.maximum(run_max, lg.max(1).values)
                run_sum = run_sum * torch.exp(run_max - new_max) + torch.exp(lg - new_max[:, None]).sum(1)
                run_max = new_max
                ic = (tgt >= c) & (tgt < c + vchunk)
                if ic.any():
                    tgt_logit[ic] = lg[ic, tgt[ic] - c]
            lse = run_max + torch.log(run_sum)
        ctx.save_for_backward(h, W, tgt, lse)
        ctx.vchunk = vchunk
        return (lse - tgt_logit).mean()

    @staticmethod
    def backward(ctx, go):
        """Return gradients for ``h`` and ``W`` (``None`` for ``tgt`` and ``vchunk``)."""
        h, W, tgt, lse = ctx.saved_tensors
        vc = ctx.vchunk
        n_rows = h.shape[0]
        vocab = W.shape[0]
        need_h, need_w = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        with _fused_ce_no_autocast(h.device):
            hf = h.float()
            gh = torch.zeros_like(hf) if need_h else None
            gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32) if need_w else None
            # Keep the scale as a tensor: float(go) would force a device sync,
            # and max(..., 1) avoids dividing by zero for an empty batch.
            sc = go.float() / max(n_rows, 1)
            for c in range(0, vocab, vc):
                Wc = W[c:c + vc].float()
                p = torch.exp(hf @ Wc.T - lse[:, None])  # softmax chunk [N, vchunk]
                ic = (tgt >= c) & (tgt < c + vc)
                if ic.any():
                    p[ic, tgt[ic] - c] -= 1.0
                p *= sc
                if need_h:
                    gh += p @ Wc
                if need_w:
                    # Each vocabulary slice is visited exactly once.
                    gW[c:c + vc] = p.T @ hf
        grad_h = gh.to(h.dtype) if need_h else None
        grad_w = gW.to(W.dtype) if need_w else None
        return grad_h, grad_w, None, None
def fused_ce(h, W, tgt, vchunk=16384):
    """Flatten ``h`` to rows and ``tgt`` to a vector, then apply ``FusedCE``."""
    rows = h.reshape(-1, h.size(-1))
    labels = tgt.reshape(-1)
    return FusedCE.apply(rows, W, labels, vchunk)
# ===== END fused_ce.py =====
# ===== BEGIN dblocks_train.py =====
"""DiffusionBlocks training mode folded into AGILLM-4 (gated by --dblock).
Block-wise EDM denoising on the real Encoder blocks, supervising AR + SAT(fixed+var)
+ NAT each step on ONE block, with grad-checkpointed layers and fused vocab-streaming
CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
Lazy-imports nB300 inside functions to avoid a circular import.
"""
import math
import random
import time
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as _ck
# Constant shared by `_edm_pre` and `_edm_w`; presumably the EDM "sigma_data"
# (assumed data standard deviation) — TODO confirm.
SD = 0.5
def _profile_active(state, args):
limit = int(getattr(args, "profile_steps", 0) or 0)
return limit > 0 and int(state.get("profile_n", 0)) < limit
def _profile_add(state, name, seconds):
if seconds is None:
return
prof = state.setdefault("profile_times", defaultdict(float))
prof[name] += float(seconds)
def _profile_tic(enabled):
if not enabled:
return None
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter()
def _profile_toc(state, name, start):
if start is None:
return
if torch.cuda.is_available():
torch.cuda.synchronize()
_profile_add(state, name, time.perf_counter() - start)
def _profile_step_done(state, args):
limit = int(getattr(args, "profile_steps", 0) or 0)
if limit <= 0:
return
n_prev = int(state.get("profile_n", 0))
if n_prev >= limit:
return
state["profile_n"] = n_prev + 1
n = int(state["profile_n"])
log_every = max(1, int(getattr(args, "profile_log_every", 25) or 25))
if n % log_every != 0 and n != limit:
return
times = state.get("profile_times", {})
keys = [
"data_stream", "tensor", "setup",
"ar_forward", "ar_ce", "ar_backward",
"sat_forward", "sat_ce", "sat_backward",
"nat_forward", "nat_ce", "nat_backward",
"opt_step", "step_total",
]
parts = []
for key in keys:
val = float(times.get(key, 0.0)) * 1000.0 / max(1, n)
if val > 0.01:
parts.append(f"{key}={val:.2f}ms")
print(f"[profile] n={n}/{limit} avg " + " ".join(parts), flush=True)
def _cdf(x):
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
def _ppf(p):
return float(torch.erfinv(torch.tensor(2 * p - 1.0)) * math.sqrt(2))
def _block_sigmas(B, smin=0.002, smax=80.0, pm=-1.2, ps=1.2):
    """Return ``B + 1`` sigma boundaries between ``smin`` and ``smax``.

    The boundaries are quantiles of a log-normal with log-mean ``pm`` and
    log-std ``ps``, spaced uniformly in probability between the CDF values
    of ``smin`` and ``smax``.
    """
    p_low = _cdf((math.log(smin) - pm) / ps)
    p_high = _cdf((math.log(smax) - pm) / ps)
    boundaries = []
    for i in range(B + 1):
        z = _ppf(p_low + (p_high - p_low) * (i / B))
        boundaries.append(float(np.exp(pm + ps * z)))
    return boundaries
def _edm_pre(s):
    """Return three scaling coefficients for per-sample noise levels ``s``.

    ``s`` is broadcast to ``[N, 1, 1]``. The returned tuple is
    ``(SD^2 / (s^2 + SD^2), s * SD / sqrt(s^2 + SD^2), 1 / sqrt(s^2 + SD^2))``.
    """
    sigma = s[:, None, None]
    skip_scale = SD**2 / (sigma**2 + SD**2)
    out_scale = sigma * SD / (sigma**2 + SD**2) ** 0.5
    in_scale = 1 / (sigma**2 + SD**2) ** 0.5
    return skip_scale, out_scale, in_scale
def _edm_w(s, wmax=5.0):
    """Return the batch mean of ``(s^2 + SD^2) / (s * SD)^2`` clamped at ``wmax``, as a float."""
    per_sample = (s**2 + SD**2) / (s * SD) ** 2
    capped = per_sample.clamp(max=wmax)
    return float(capped.mean())
def _dblock_init(core, args):
    """Build the DiffusionBlocks scheduling state for ``core``.

    ``core.blocks`` is split into ``args.dblock_blocks`` (default 4)
    contiguous groups of ``L // B`` layers; the last group also takes the
    remainder. The requested block count is clamped to ``[1, L]``. Larger
    values used to assign layer indices past the end of ``core.blocks``, and
    non-positive values raised ``IndexError``.

    Returns:
        dict with keys ``B`` (block count), ``assign`` (layer indices per
        block), ``bsig`` (``B + 1`` sigma boundaries), ``step``, ``counts``
        and ``loss_ema``.
    """
    L = len(core.blocks)
    requested = int(getattr(args, "dblock_blocks", 4))
    B = max(1, min(requested, L))
    if B != requested:
        print(f"[dblock] dblock_blocks={requested} is not usable with {L} layers; using {B}", flush=True)
    sp = max(1, L // B)
    asg = [list(range(i * sp, (i + 1) * sp)) for i in range(B)]
    # The final block absorbs the layers left over by the integer division.
    asg[-1] = list(range((B - 1) * sp, L))
    bsig = _block_sigmas(B)
    schedule = getattr(args, "dblock_schedule", "loss_balanced")
    print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
    print(f"[dblock] schedule={schedule} sigma boundaries: {[round(x, 3) for x in bsig]}")
    return {
        "B": B,
        "assign": asg,
        "bsig": bsig,
        "step": 0,
        "counts": [0 for _ in range(B)],
        "loss_ema": [None for _ in range(B)],
    }
def _choose_block(state, args):
B = state["B"]
schedule = str(getattr(args, "dblock_schedule", "loss_balanced") or "loss_balanced").lower()
step = int(state.get("step", 0))
counts = state.setdefault("counts", [0 for _ in range(B)])
emas = state.setdefault("loss_ema", [None for _ in range(B)])
if schedule == "random":
return random.randrange(B)
if schedule == "roundrobin":
return step % B
explore = float(getattr(args, "dblock_explore", 0.05))
warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2)))
if step < warmup or any(c == 0 for c in counts):
return min(range(B), key=lambda i: (counts[i], i))
if explore > 0.0 and random.random() < explore:
return min(range(B), key=lambda i: (counts[i], i))
return max(range(B), key=lambda i: (-1.0 if emas[i] is None else emas[i], -counts[i]))
def _sample_sigma(ids, lo, hi, args, state):
cur_step = int(state.get("step", 0))
curriculum = int(getattr(args, "dblock_sigma_curriculum_steps", 0))
if curriculum > 0:
frac = min(1.0, max(0.05, (cur_step + 1) / float(curriculum)))
hi = lo * ((hi / max(lo, 1e-8)) ** frac)
sig_np = np.exp(
np.random.uniform(
math.log(max(lo, 1e-4)),
math.log(max(hi, lo + 1e-4)),
ids.size(0),
).astype("float32")
)
return torch.from_numpy(sig_np).to(ids.device)
def _maybe_log(
state,
args,
bi,
layers,
ar_val,
sat_val,
nat_val,
total_val,
peak_alloc,
peak_reserved,
objective=None,
raw_avg=None,
raw_total=None,
edm_weight=None,
):
log_every = int(getattr(args, "dblock_log_every", 50))
step = int(state.get("step", 0))
if log_every <= 0 or step % log_every != 0:
return
counts = ",".join(str(x) for x in state.get("counts", []))
emas = ",".join("nan" if x is None else f"{x:.2f}" for x in state.get("loss_ema", []))
mem = ""
if peak_alloc is not None:
mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
display = float(raw_avg) if raw_avg is not None and math.isfinite(float(raw_avg)) else float(total_val)
raw_part = ""
if raw_total is not None:
raw_part += f" raw_sum={float(raw_total):.3f}"
if edm_weight is not None:
raw_part += f" edm_w={float(edm_weight):.3f}"
print(
f"[dblock] step={step} block={bi} obj={objective or 'mixed'} layers={layers} "
f"loss={display:.3f} weighted={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f}"
f"{raw_part} counts=[{counts}] ema=[{emas}]{mem}",
flush=True,
)
def _update_stats(state, bi, loss_value):
B = state["B"]
counts = state.setdefault("counts", [0 for _ in range(B)])
emas = state.setdefault("loss_ema", [None for _ in range(B)])
counts[bi] += 1
prev = emas[bi]
beta = 0.96
emas[bi] = float(loss_value) if prev is None else beta * float(prev) + (1.0 - beta) * float(loss_value)
state["step"] = int(state.get("step", 0)) + 1
def _activation_offload_enabled(args):
return bool(getattr(args, "dblock_activation_offload", False)) and torch.cuda.is_available()
def _activation_offload_hooks(args):
min_bytes = int(float(getattr(args, "dblock_activation_offload_min_mb", 1.0) or 1.0) * 1024 * 1024)
def pack(t):
if not torch.is_tensor(t) or not t.is_cuda or not t.is_floating_point() or t.numel() * t.element_size() < min_bytes:
return t
return ("cpu_offload", t.device, t.detach().to("cpu", non_blocking=True))
def unpack(x):
if isinstance(x, tuple) and len(x) == 3 and x[0] == "cpu_offload":
_, dev, cpu_t = x
return cpu_t.to(dev, non_blocking=True)
return x
return torch.autograd.graph.saved_tensors_hooks(pack, unpack)
def _run_block(block, x, mask, use_checkpoint, args=None):
if use_checkpoint:
return _ck.checkpoint(lambda y, block=block: block(y, mask), x, use_reentrant=False)
if args is not None and _activation_offload_enabled(args):
with _activation_offload_hooks(args):
return block(x, mask)
return block(x, mask)
def _dblock_checkpoint_this_layer(args, base_enabled, layer_pos, layer_count=None):
if not base_enabled:
return False
pos = int(layer_pos)
count = int(layer_count or 0)
skip_tail = max(0, int(getattr(args, "dblock_checkpoint_skip_tail", 0) or 0))
if skip_tail > 0 and count > 0 and pos >= max(0, count - skip_tail):
return False
stride = int(getattr(args, "dblock_checkpoint_stride", 1) or 1)
if stride <= 0:
return False
if stride == 1:
return True
return (pos % stride) == 0
def _sample_token_loss_inputs(hidden, targets, max_tokens):
max_tokens = int(max_tokens or 0)
if max_tokens <= 0:
return hidden.contiguous(), targets.contiguous(), int(targets.numel()), int(targets.numel())
flat_targets = targets.reshape(-1)
total = int(flat_targets.numel())
if total <= max_tokens:
return hidden.contiguous(), targets.contiguous(), total, total
# With-replacement sampling avoids building a full randperm each step; the sampled
# mean remains an unbiased estimator of the dense token CE mean.
idx = torch.randint(total, (max_tokens,), device=targets.device)
flat_hidden = hidden.reshape(total, hidden.size(-1))
return flat_hidden.index_select(0, idx).contiguous(), flat_targets.index_select(0, idx).contiguous(), int(max_tokens), total
def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic):
mode = str(getattr(args, "dblock_objective_mode", "periodic") or "periodic").lower()
if mode != "stochastic":
return ar_weight > 0.0, sat_weight > 0.0 and do_sat_periodic, nat_weight > 0.0 and do_nat_periodic, "periodic"
choices = []
probs = []
if ar_weight > 0.0:
choices.append("ar")
probs.append(max(0.0, float(getattr(args, "dblock_ar_prob", 0.80))))
if sat_weight > 0.0 and not getattr(args, "ar_only", False):
choices.append("sat")
probs.append(max(0.0, float(getattr(args, "dblock_sat_prob", 0.10))))
if nat_weight > 0.0 and not getattr(args, "ar_only", False):
choices.append("nat")
probs.append(max(0.0, float(getattr(args, "dblock_nat_prob", 0.10))))
if not choices:
return False, False, False, "none"
total = sum(probs)
if total <= 0.0:
probs = [1.0 / len(choices) for _ in choices]
else:
probs = [p / total for p in probs]
picked = random.choices(choices, weights=probs, k=1)[0]
return picked == "ar", picked == "sat", picked == "nat", picked
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
    """Run one DiffusionBlocks training step on a single block of ``core``.

    One block index is chosen via ``_choose_block``. Up to three objectives
    (AR, SAT, NAT) are then evaluated through only that block's layers; each
    one runs its own backward pass so gradients accumulate, and a single
    clipped optimizer step follows.

    Args:
        core: model exposing ``emb``, ``blocks`` and ``ln``.
        ar_h, sat_h, nat_h: output heads; each head's ``proj.weight`` is used
            as the classifier matrix (``sat_h.gate`` may be None, and
            ``nat_h`` may be None to disable NAT).
        opt: optimizer whose parameter groups are clipped and stepped.
        scaler: gradient scaler used for backward/unscale/step/update.
        args: parsed options (read through ``getattr`` with defaults).
        ids: token id batch; ``ids.size(1)`` is the sequence length.
        state: dict produced by ``_dblock_init``; updated in place.

    Returns:
        The mean of the unweighted losses of the objectives that ran, or the
        (non-finite) weighted total when the optimizer step was skipped.
    """
    # Lazy import: resolved through the sys.modules alias, avoids a circular import.
    import nB300_agillm4 as M
    prof = _profile_active(state, args)
    _step_t = _profile_tic(prof)
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    _setup_t = _profile_tic(prof)
    B = state["B"]
    asg = state["assign"]
    bs = state["bsig"]
    T = ids.size(1)
    use_layer_checkpoint = bool(getattr(args, "grad_checkpoint", False))
    # Pick the block, its sigma interval and its layer indices.
    bi = _choose_block(state, args)
    lo, hi = sorted([bs[bi], bs[bi + 1]])
    layers = asg[bi]
    # One noise level per batch row, plus the three scaling coefficients and
    # the scalar loss weight derived from it.
    sig = _sample_sigma(ids, lo, hi, args, state)
    cs, co, ci = _edm_pre(sig)
    w = _edm_w(sig, float(getattr(args, "dblock_edm_wmax", 5.0)))
    SATB = M.SAT_BLOCK
    ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
    sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
    nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
    # SAT runs every `sat_every` steps (counted from step + 1) unless ar_only is set.
    do_sat_periodic = (not getattr(args, "ar_only", False)) and (
        int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
    )
    # NAT additionally needs a head and a positive `nat_every`.
    do_nat_periodic = (
        nat_h is not None
        and (not getattr(args, "ar_only", False))
        and int(getattr(args, "nat_every", 1)) > 0
        and (
            int(getattr(args, "nat_every", 1)) <= 1
            or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
        )
    )
    run_ar, run_sat, run_nat, objective = _choose_objectives(
        state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
    )
    _profile_toc(state, "setup", _setup_t)
    ar_val = 0.0
    sat_val = 0.0
    nat_val = 0.0
    ar_raw_val = 0.0
    sat_raw_val = 0.0
    nat_raw_val = 0.0
    if run_ar:
        # --- AR objective: noised embeddings through the block under a causal mask.
        causal = M.causal_mask(T, structured=M.use_structured_masks(args))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            emb = core.emb(ids)
            zt = emb + sig[:, None, None] * torch.randn_like(emb)
            h = ci * zt
            for lpos, li in enumerate(layers):
                h = _run_block(core.blocks[li], h, causal, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            # Combine the noisy input and the block output, then final norm.
            Dn = core.ln(cs * zt + co * h)
        _profile_toc(state, "ar_forward", _t)
        _t = _profile_tic(prof)
        # Position t predicts token t + 1.
        ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs(
            Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
        )
        ar_raw = fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
        ar_raw_val = float(ar_raw.detach())
        ar = ar_weight * w * ar_raw
        ar_val = float(ar.detach())
        _profile_toc(state, "ar_ce", _t)
        _t = _profile_tic(prof)
        scaler.scale(ar).backward()
        _profile_toc(state, "ar_backward", _t)
        # Drop references before the next objective allocates its own activations.
        del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar_raw, ar, ar_used, ar_total
    if run_sat:
        # --- SAT objective: same noising scheme with the SAT mask and fresh noise.
        smask = M.sat_mask(T, structured=M.use_structured_masks(args))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            emb2 = core.emb(ids)
            zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
            h2 = ci * zt2
            for lpos, li in enumerate(layers):
                h2 = _run_block(core.blocks[li], h2, smask, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            Ds = core.ln(cs * zt2 + co * h2)
            # Only the last SATB positions are supervised.
            last = Ds[:, -SATB:]
        _profile_toc(state, "sat_forward", _t)
        _t = _profile_tic(prof)
        sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs(
            last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
        )
        with M.amp(args.amp):
            satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
            # Optional gate loss on the first position; the target class is always 1.
            satv = (
                M.EMIT_LAMBDA
                * F.cross_entropy(
                    sat_h.gate(Ds[:, 0].float()),
                    torch.ones(ids.size(0), dtype=torch.long, device=ids.device),
                )
                if sat_h.gate is not None
                else 0.0
            )
        sat_raw = satf + satv
        sat_raw_val = float(sat_raw.detach())
        sat = sat_weight * w * sat_raw
        _profile_toc(state, "sat_ce", _t)
        sat_val = float(sat.detach())
        _t = _profile_tic(prof)
        scaler.scale(sat).backward()
        _profile_toc(state, "sat_backward", _t)
        del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat_raw, sat
    if run_nat:
        # --- NAT objective: masked-token prediction, no noise and no attention mask.
        ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
        nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            nat_in = nat_ids.clone()
            m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
            # Guarantee at least one masked position so the loss is defined.
            if not bool(m.any()):
                m[..., -1] = True
            nat_in[m] = M.BLANK
            hn = core.emb(nat_in)
            for lpos, li in enumerate(layers):
                hn = _run_block(core.blocks[li], hn, None, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            Dnat = core.ln(hn)
        _profile_toc(state, "nat_forward", _t)
        _t = _profile_tic(prof)
        # Loss is taken on the masked positions only.
        nat_hidden = Dnat[m]
        nat_targets = nat_ids[m]
        nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs(
            nat_hidden.unsqueeze(0), nat_targets.unsqueeze(0), int(getattr(args, "dblock_nat_loss_tokens", 0))
        )
        nat_raw = fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
        nat_raw_val = float(nat_raw.detach())
        # Unlike AR/SAT, the NAT loss is not multiplied by the noise weight w.
        nat = nat_weight * nat_raw
        nat_val = float(nat.detach())
        _profile_toc(state, "nat_ce", _t)
        _t = _profile_tic(prof)
        scaler.scale(nat).backward()
        _profile_toc(state, "nat_backward", _t)
        del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat_raw, nat, nat_used, nat_total
    total_val = ar_val + sat_val + nat_val
    raw_total_val = ar_raw_val + sat_raw_val + nat_raw_val
    raw_count = int(bool(run_ar)) + int(bool(run_sat)) + int(bool(run_nat))
    raw_avg_val = raw_total_val / max(1, raw_count)
    if not math.isfinite(total_val):
        # Non-finite loss: discard the accumulated gradients and skip the optimizer step.
        opt.zero_grad(set_to_none=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
        _profile_toc(state, "step_total", _step_t)
        _profile_step_done(state, args)
        _update_stats(state, bi, total_val)
        return total_val
    _t = _profile_tic(prof)
    # Unscale before clipping so the 1.0 norm bound applies to the true gradients.
    scaler.unscale_(opt)
    nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad(set_to_none=True)
    _profile_toc(state, "opt_step", _t)
    peak_alloc = None
    peak_reserved = None
    if torch.cuda.is_available():
        peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
        peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
    _profile_toc(state, "step_total", _step_t)
    _profile_step_done(state, args)
    _update_stats(state, bi, total_val)
    _maybe_log(
        state,
        args,
        bi,
        layers,
        ar_val,
        sat_val,
        nat_val,
        total_val,
        peak_alloc,
        peak_reserved,
        objective=objective,
        raw_avg=raw_avg_val,
        raw_total=raw_total_val,
        edm_weight=w,
    )
    return raw_avg_val
# ===== END dblocks_train.py =====
# ===== BEGIN nB300_agillm4.py =====
#!/usr/bin/env python3
# n.py - Joint AR+SAT+NAT Trainer with Expansion Ratio Testing
# Enhanced inference: checkpoint name, tok/s, UK time
import argparse, copy, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess
from pathlib import Path
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timezone
_ASCII_LOG_TRANSLATION = str.maketrans({
"\u2018": "'",
"\u2019": "'",
"\u201a": "'",
"\u201b": "'",
"\u201c": '"',
"\u201d": '"',
"\u201e": '"',
"\u201f": '"',
"\u2013": "-",
"\u2014": "-",
"\u2212": "-",
"\u2026": "...",
"\u00a0": " ",
})
def _ascii_log_text(text: str) -> str:
    """Return ``text`` as pure ASCII.

    Known typographic characters are mapped to ASCII look-alikes first; any
    remaining non-ASCII character becomes ``?``.
    """
    folded = str(text).translate(_ASCII_LOG_TRANSLATION)
    return folded.encode("ascii", "replace").decode("ascii")
class _AsciiLogStream:
def __init__(self, wrapped):
self._wrapped = wrapped
def write(self, text):
return self._wrapped.write(_ascii_log_text(text))
def flush(self):
return self._wrapped.flush()
def isatty(self):
return self._wrapped.isatty()
def fileno(self):
return self._wrapped.fileno()
@property
def encoding(self):
return "ascii"
def __getattr__(self, name):
return getattr(self._wrapped, name)
# When output is not a terminal (e.g. redirected to a log file), wrap stdout
# and stderr so logs stay pure ASCII. Set NB300_RAW_UNICODE_LOGS to 1/true/yes
# to opt out.
if not sys.stdout.isatty():
    if os.environ.get("NB300_RAW_UNICODE_LOGS", "").lower() not in {"1", "true", "yes"}:
        sys.stdout, sys.stderr = _AsciiLogStream(sys.stdout), _AsciiLogStream(sys.stderr)
# Resolved location of this file; the default train log and checkpoint
# directory live next to it.
STATUS_SCRIPT_PATH = Path(__file__).resolve()
STATUS_DEFAULT_LOG = STATUS_SCRIPT_PATH.parent / "train.log"
STATUS_DEFAULT_SAVE_DIR = STATUS_SCRIPT_PATH.parent / "ckpts_expansion"
# Matches a whole progress line of the form
#   [12.5%] 1,000/8,000 tok | 950 tok/s | loss=2.31 B=4 L=1024 step=10 eta=... elapsed=...
# where the step/eta/elapsed suffixes are each optional.
_STATUS_PROGRESS_RE = re.compile(
    r"^\[(?P<percent>\d+(?:\.\d+)?)%\]\s+"
    r"(?P<seen>[\d,]+)/(?P<target>[\d,]+)\s+tok\s+\|\s+"
    r"(?P<tok_s>[\d.]+)\s+tok/s\s+\|\s+"
    r"loss=(?P<loss>-?[\d.]+)\s+B=(?P<batch>\d+)\s+L=(?P<block>\d+)"
    r"(?:\s+step=(?P<step>\d+))?"
    r"(?:\s+eta=(?P<eta>\S+))?"
    r"(?:\s+elapsed=(?P<elapsed>\S+))?"
    r"\s*$"
)
# Matches "[delta] saved <name>.pt (<hex>...)" anywhere in a line.
_STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P<name>\S+?\.pt)\s+\((?P<sha>[0-9a-f]+)\.\.\.\)")
# Extracts the number following the literal "step" (e.g. from a checkpoint file name).
_STATUS_STEP_RE = re.compile(r"step(?P<step>\d+)")
def _status_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
return datetime.fromtimestamp(ts, tz=timezone.utc).astimezone().isoformat(timespec="seconds")
def _status_human_duration(seconds: Optional[float]) -> Optional[str]:
if seconds is None:
return None
total = max(0, int(seconds))
days, rem = divmod(total, 86400)
hours, rem = divmod(rem, 3600)
minutes, secs = divmod(rem, 60)
parts = []
if days:
parts.append(f"{days}d")
if hours or parts:
parts.append(f"{hours}h")
if minutes or parts:
parts.append(f"{minutes}m")
parts.append(f"{secs}s")
return " ".join(parts)
def _status_compact_duration(seconds: Optional[float]) -> str:
if seconds is None:
return "unknown"
try:
if not math.isfinite(float(seconds)):
return "unknown"
except Exception:
return "unknown"
total = max(0, int(seconds))
years, rem = divmod(total, 365 * 86400)
days, rem = divmod(rem, 86400)
hours, rem = divmod(rem, 3600)
minutes, secs = divmod(rem, 60)
if years:
return f"{years}y{days}d{hours}h"
if days:
return f"{days}d{hours}h{minutes}m"
if hours:
return f"{hours}h{minutes}m{secs}s"
if minutes:
return f"{minutes}m{secs}s"
return f"{secs}s"
def _status_format_int(value: Optional[int]) -> str:
return "?" if value is None else f"{value:,}"
def _status_parse_step(text: str) -> Optional[int]:
    """Return the first ``step<digits>`` number found in ``text``, or None."""
    found = _STATUS_STEP_RE.search(text)
    if found is None:
        return None
    return int(found.group("step"))
def _status_resolve_ckpt_path(raw_path: str, base_dir: Path) -> Path:
ckpt_path = Path(raw_path)
return ckpt_path if ckpt_path.is_absolute() else (base_dir / ckpt_path).resolve()
def _status_read_cmdline(proc_dir: Path) -> Optional[List[str]]:
try:
data = (proc_dir / "cmdline").read_bytes().split(b"\0")
return [item.decode("utf-8", errors="ignore") for item in data if item]
except Exception:
return None
def _status_resolve_proc_arg(proc_dir: Path, raw_arg: str) -> Optional[Path]:
try:
arg_path = Path(raw_arg)
if arg_path.is_absolute():
return arg_path.resolve()
cwd = Path(os.readlink(proc_dir / "cwd"))
return (cwd / arg_path).resolve()
except Exception:
return None
def _status_proc_uptime(proc_dir: Path) -> Optional[float]:
try:
proc_uptime = float((Path("/proc") / "uptime").read_text().split()[0])
stat_text = (proc_dir / "stat").read_text()
after = stat_text[stat_text.rfind(")") + 2:].split()
start_ticks = float(after[19])
clock_ticks = os.sysconf(os.sysconf_names["SC_CLK_TCK"])
return max(0.0, proc_uptime - (start_ticks / clock_ticks))
except Exception:
return None
def _status_find_trainers(script_path: Path) -> List[Dict[str, Any]]:
    """Find running processes that execute ``script_path`` with a ``train`` argument.

    Scans the numeric entries of ``/proc``. A process matches when ``"train"``
    is one of its arguments and one argument resolves (relative to that
    process's working directory) to ``script_path``.

    Returns:
        Matching processes sorted by pid, each as a dict with ``pid``,
        ``cmdline``, ``args``, ``cwd``, ``uptime_seconds`` and
        ``uptime_human``. The list is empty when ``/proc`` cannot be listed
        (e.g. on a platform without procfs).
    """
    matches: List[Dict[str, Any]] = []
    try:
        proc_entries = list(Path("/proc").iterdir())
    except OSError:
        # No procfs (or no permission): nothing can be detected.
        return matches
    for proc_dir in proc_entries:
        if not proc_dir.name.isdigit():
            continue
        args = _status_read_cmdline(proc_dir)
        if not args or "train" not in args:
            continue
        resolved_script = None
        for arg in args:
            # Cheap file-name check before the more expensive path resolution.
            if Path(arg).name != script_path.name:
                continue
            candidate = _status_resolve_proc_arg(proc_dir, arg)
            if candidate == script_path:
                resolved_script = candidate
                break
        if resolved_script is None:
            continue
        uptime_seconds = _status_proc_uptime(proc_dir)
        try:
            cwd = str(Path(os.readlink(proc_dir / "cwd")))
        except Exception:
            cwd = None
        matches.append({
            "pid": int(proc_dir.name),
            "cmdline": " ".join(args),
            "args": args,
            "cwd": cwd,
            "uptime_seconds": round(uptime_seconds, 3) if uptime_seconds is not None else None,
            "uptime_human": _status_human_duration(uptime_seconds),
        })
    return sorted(matches, key=lambda item: item["pid"])
def _status_parse_progress_line(line: str) -> Optional[Dict[str, Any]]:
    """Parse one training progress line into a dict, or return None.

    None is returned both when the line does not match the progress pattern
    and when a matched field is not a valid number. The pattern admits
    strings such as ``1.2.3`` or ``,``, and raising on those would abort the
    caller's scan of the whole log.
    """
    stripped = line.strip()
    match = _STATUS_PROGRESS_RE.match(stripped)
    if not match:
        return None
    try:
        tok_per_sec = float(match.group("tok_s"))
        loss = float(match.group("loss"))
        return {
            "raw_line": stripped,
            "percent": float(match.group("percent")),
            "seen_tokens": int(match.group("seen").replace(",", "")),
            "target_tokens": int(match.group("target").replace(",", "")),
            # Whole-number rates are reported as int, fractional ones as float.
            "tok_per_sec": int(tok_per_sec) if tok_per_sec.is_integer() else tok_per_sec,
            "loss": loss,
            "batch": int(match.group("batch")),
            "block": int(match.group("block")),
            "step": int(match.group("step")) if match.group("step") else None,
            "eta": match.group("eta"),
            "elapsed": match.group("elapsed"),
        }
    except ValueError:
        return None
def _status_parse_delta_line(line: str) -> Optional[Dict[str, Any]]:
    """Parse a ``[delta] saved <name>.pt (<sha>...)`` log line, or return None."""
    found = _STATUS_DELTA_RE.search(line)
    if found is None:
        return None
    checkpoint_name = found.group("name")
    record: Dict[str, Any] = {
        "raw_line": line.strip(),
        "name": checkpoint_name,
        "step": _status_parse_step(checkpoint_name),
        "sha_prefix": found.group("sha"),
        "source": "log",
    }
    return record
def _status_scan_log(log_path: Path) -> tuple[Dict[str, Any], Optional[Dict[str, Any]], Optional[Dict[str, Any]], List[str]]:
now = time.time()
info: Dict[str, Any] = {
"path": str(log_path),
"exists": log_path.exists(),
"mtime": None,
"mtime_iso": None,
"age_seconds": None,
"age_human": None,
"size_bytes": None,
}
warnings: List[str] = []
if not log_path.exists():
warnings.append(f"train log missing: {log_path}")
return info, None, None, warnings
try:
st = log_path.stat()
info["mtime"] = st.st_mtime
info["mtime_iso"] = _status_iso(st.st_mtime)
info["age_seconds"] = round(max(0.0, now - st.st_mtime), 3)
info["age_human"] = _status_human_duration(info["age_seconds"])
info["size_bytes"] = st.st_size
except Exception as exc:
warnings.append(f"failed to stat train log: {exc}")
last_progress = None
last_delta = None
try:
with log_path.open("r", encoding="utf-8", errors="ignore") as handle:
for raw_line in handle:
line = raw_line.rstrip("\n")
progress = _status_parse_progress_line(line)
if progress is not None:
last_progress = progress
delta = _status_parse_delta_line(line)
if delta is not None:
last_delta = delta
except Exception as exc:
warnings.append(f"failed to read train log: {exc}")
return info, last_progress, last_delta, warnings
def _status_latest_full_checkpoint(save_dir: Path, base_dir: Path) -> tuple[Dict[str, Any], List[str]]:
latest_path = save_dir / "latest.json"
info: Dict[str, Any] = {
"metadata_path": str(latest_path),
"exists": latest_path.exists(),
"raw_path": None,
"checkpoint_path": None,
"checkpoint_name": None,
"checkpoint_exists": None,
"step": None,
"checkpoint_mtime": None,
"checkpoint_mtime_iso": None,
}
warnings: List[str] = []
if not latest_path.exists():
warnings.append(f"latest.json missing: {latest_path}")
return info, warnings
try:
payload = json.loads(latest_path.read_text(encoding="utf-8"))
except Exception as exc:
warnings.append(f"failed to parse latest.json: {exc}")
return info, warnings
raw_path = payload.get("path")
info["raw_path"] = raw_path
info["step"] = payload.get("step")
if raw_path:
ckpt_path = _status_resolve_ckpt_path(raw_path, base_dir)
info["checkpoint_path"] = str(ckpt_path)
info["checkpoint_name"] = ckpt_path.name
info["checkpoint_exists"] = ckpt_path.exists()
if ckpt_path.exists():
try:
st = ckpt_path.stat()
info["checkpoint_mtime"] = st.st_mtime
info["checkpoint_mtime_iso"] = _status_iso(st.st_mtime)
except Exception as exc:
warnings.append(f"failed to stat full checkpoint: {exc}")
else:
warnings.append(f"latest.json points to missing checkpoint: {ckpt_path}")
return info, warnings
def _status_newest_delta(save_dir: Path) -> tuple[Optional[Dict[str, Any]], List[str]]:
warnings: List[str] = []
if not save_dir.exists():
warnings.append(f"save dir missing: {save_dir}")
return None, warnings
try:
candidates = [item for item in save_dir.glob("*_delta_step*.pt") if item.is_file()]
except Exception as exc:
warnings.append(f"failed to list delta checkpoints: {exc}")
return None, warnings
if not candidates:
warnings.append(f"no delta checkpoints found in {save_dir}")
return None, warnings
newest = max(candidates, key=lambda item: item.stat().st_mtime)
st = newest.stat()
return {
"path": str(newest),
"name": newest.name,
"step": _status_parse_step(newest.name),
"mtime": st.st_mtime,
"mtime_iso": _status_iso(st.st_mtime),
"size_bytes": st.st_size,
"source": "disk",
}, warnings
def _status_gpu_info() -> tuple[Optional[Dict[str, Any]], List[str]]:
warnings: List[str] = []
try:
result = subprocess.run(
[
"nvidia-smi",
"--query-gpu=name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw",
"--format=csv,noheader,nounits",
],
capture_output=True,
text=True,
timeout=5,
check=False,
)
except FileNotFoundError:
return None, warnings
except Exception as exc:
warnings.append(f"failed to query GPU status: {exc}")
return None, warnings
if result.returncode != 0:
warnings.append(result.stderr.strip() or "nvidia-smi returned non-zero exit status")
return None, warnings
lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
if not lines:
return None, warnings
if len(lines) > 1:
warnings.append("multiple GPUs detected; reporting the first GPU only")
parts = [part.strip() for part in lines[0].split(",")]
if len(parts) != 6:
warnings.append(f"unexpected nvidia-smi format: {lines[0]}")
return None, warnings
def _parse_int(raw: str) -> Optional[int]:
try:
return int(float(raw))
except Exception:
return None
def _parse_float(raw: str) -> Optional[float]:
try:
return float(raw)
except Exception:
return None
return {
"name": parts[0],
"utilization_gpu": _parse_int(parts[1]),
"memory_used_mib": _parse_int(parts[2]),
"memory_total_mib": _parse_int(parts[3]),
"temperature_c": _parse_int(parts[4]),
"power_draw_w": _parse_float(parts[5]),
}, warnings
def _status_choose_delta(from_log: Optional[Dict[str, Any]], from_disk: Optional[Dict[str, Any]], warnings: List[str]) -> Optional[Dict[str, Any]]:
if from_log and from_disk:
log_step = from_log.get("step")
disk_step = from_disk.get("step")
if log_step is not None and disk_step is not None:
if log_step != disk_step:
warnings.append(
f"log delta step {log_step} and newest on-disk delta step {disk_step} differ; using the newer step"
)
if disk_step >= log_step:
merged = dict(from_disk)
merged["source"] = "disk+log" if disk_step == log_step else "disk"
if disk_step == log_step:
merged["sha_prefix"] = from_log.get("sha_prefix")
return merged
return dict(from_log)
return dict(from_disk)
if from_disk:
return dict(from_disk)
if from_log:
return dict(from_log)
return None
def _collect_status(log_path: Path, save_dir: Path) -> tuple[Dict[str, Any], int]:
    """Gather a read-only snapshot of trainer state for the ``status`` command.

    Combines the trainer process lookup, the training-log scan, the newest
    delta checkpoint, the latest full checkpoint metadata and GPU telemetry
    into a single dict.

    Args:
        log_path: Training log to scan; ``~`` is expanded.
        save_dir: Requested checkpoint directory; ``~`` is expanded.

    Returns:
        ``(status, exit_code)``. ``exit_code`` is 1 only when more than one
        trainer process matches; in that case ``status["error"]`` and
        ``status["processes"]`` are set and nothing else is collected.
        Otherwise ``exit_code`` is 0 and problems are reported through
        ``status["warnings"]``.
    """
    checked_at = time.time()
    requested_save_dir = save_dir.expanduser()
    log_path = log_path.expanduser()
    status: Dict[str, Any] = {
        "checked_at": checked_at,
        "checked_at_iso": _status_iso(checked_at),
        "running": False,
        "process": None,
        "progress": None,
        "delta_checkpoint": None,
        "delta_from_log": None,
        "delta_on_disk": None,
        "latest_full_checkpoint": None,
        "log": None,
        "gpu": None,
        "save_dir": {
            "requested_path": str(requested_save_dir),
            "path": str(requested_save_dir),
            "exists": requested_save_dir.exists(),
            "source": "requested",
        },
        "warnings": [],
    }
    # Alias of the list stored in the dict: every append below is visible in
    # the returned status.
    warnings = status["warnings"]
    matches = _status_find_trainers(STATUS_SCRIPT_PATH)
    if len(matches) > 1:
        # Ambiguous: refuse to guess which trainer the caller means.
        status["error"] = "multiple active n.py train processes found"
        status["processes"] = matches
        return status, 1
    if matches:
        status["running"] = True
        status["process"] = matches[0]
    save_dir = requested_save_dir
    if status["process"] and status["process"].get("cwd"):
        # The trainer may have been launched from another working directory,
        # in which case a same-named directory under its cwd can hold the
        # real checkpoints.
        proc_cwd = Path(status["process"]["cwd"])
        alt_save_dir = (proc_cwd / requested_save_dir.name).resolve()
        if alt_save_dir != requested_save_dir and alt_save_dir.exists():
            # Score each directory by the evidence it holds (one point for a
            # delta, one for an existing full checkpoint). Warnings from
            # these probe calls are deliberately discarded.
            requested_delta, _ = _status_newest_delta(requested_save_dir)
            requested_full, _ = _status_latest_full_checkpoint(requested_save_dir, STATUS_SCRIPT_PATH.parent)
            alt_delta, _ = _status_newest_delta(alt_save_dir)
            alt_full, _ = _status_latest_full_checkpoint(alt_save_dir, proc_cwd)
            requested_score = int(requested_delta is not None) + int(bool(requested_full.get("checkpoint_exists")))
            alt_score = int(alt_delta is not None) + int(bool(alt_full.get("checkpoint_exists")))
            # Switch only on strictly more evidence; ties keep the request.
            if alt_score > requested_score:
                save_dir = alt_save_dir
                status["save_dir"] = {
                    "requested_path": str(requested_save_dir),
                    "path": str(save_dir),
                    "exists": save_dir.exists(),
                    "source": "process_cwd_fallback",
                }
                warnings.append(
                    f"using process cwd save dir fallback: {save_dir} (requested {requested_save_dir})"
                )
    log_info, progress, delta_from_log, log_warnings = _status_scan_log(log_path)
    warnings.extend(log_warnings)
    status["log"] = log_info
    status["progress"] = progress
    status["delta_from_log"] = delta_from_log
    # Base directory handed to the latest-full-checkpoint lookup: the script
    # directory by default, the trainer's cwd when the fallback was taken.
    latest_base_dir = STATUS_SCRIPT_PATH.parent
    if status["save_dir"].get("source") == "process_cwd_fallback" and status["process"] and status["process"].get("cwd"):
        latest_base_dir = Path(status["process"]["cwd"])
    latest_full, latest_warnings = _status_latest_full_checkpoint(save_dir, latest_base_dir)
    warnings.extend(latest_warnings)
    status["latest_full_checkpoint"] = latest_full
    delta_on_disk, delta_warnings = _status_newest_delta(save_dir)
    warnings.extend(delta_warnings)
    status["delta_on_disk"] = delta_on_disk
    status["delta_checkpoint"] = _status_choose_delta(delta_from_log, delta_on_disk, warnings)
    gpu, gpu_warnings = _status_gpu_info()
    warnings.extend(gpu_warnings)
    status["gpu"] = gpu
    # A running trainer whose log has not changed for more than 600 seconds
    # is flagged as possibly stalled.
    if status["running"] and log_info.get("age_seconds") is not None and log_info["age_seconds"] > 600:
        warnings.append(f"train log appears stale while trainer is running ({log_info['age_human']} old)")
    if log_info.get("exists") and progress is None:
        warnings.append("no parseable progress line found in train log")
    latest_step = latest_full.get("step") if latest_full else None
    delta_step = status["delta_checkpoint"].get("step") if status["delta_checkpoint"] else None
    if latest_step is not None and delta_step is not None and latest_step < delta_step:
        warnings.append(f"latest.json step {latest_step} lags newest delta step {delta_step}")
    # NOTE(review): this warning is only emitted when there is also no
    # progress line; a stopped trainer with a parseable log stays silent
    # here. Confirm that is intended.
    if not status["running"] and progress is None:
        warnings.append("no active trainer process found")
    return status, 0
def _format_status_text(status: Dict[str, Any]) -> str:
lines = [f"AGILLM status @ {status.get('checked_at_iso')}"]
if status.get("error"):
lines.append(f"Error: {status['error']}")
for proc in status.get("processes", []):
lines.append(f"- pid {proc.get('pid')}: {proc.get('cmdline')}")
return "\n".join(lines)
process = status.get("process")
if status.get("running") and process:
lines.append(f"Process: RUNNING | pid {process.get('pid')} | uptime {process.get('uptime_human') or 'unknown'}")
lines.append(f"Cmd: {process.get('cmdline')}")
else:
lines.append("Process: NOT RUNNING")
progress = status.get("progress")
if progress:
eta = progress.get("eta")
if not eta and progress.get("tok_per_sec"):
remaining = max(0, progress["target_tokens"] - progress["seen_tokens"])
eta = _status_compact_duration(remaining / float(progress["tok_per_sec"]))
lines.append(
"Progress: "
f"{progress['percent']:.1f}% | "
f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | "
f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | "
f"B={progress['batch']} L={progress['block']}"
+ (f" | step {progress['step']}" if progress.get("step") else "")
+ (f" | ETA {eta}" if eta else "")
)
else:
lines.append("Progress: unavailable")
log_info = status.get("log") or {}
if log_info.get("exists"):
lines.append(
f"Log: {log_info.get('path')} | updated {log_info.get('age_human') or 'unknown'} ago | "
f"mtime {log_info.get('mtime_iso')}"
)
else:
lines.append(f"Log: missing ({log_info.get('path')})")
delta = status.get("delta_checkpoint")
if delta:
line = f"Delta: {delta.get('name')} | step {delta.get('step')} | source {delta.get('source')}"
if delta.get("path"):
line += f" | {delta['path']}"
lines.append(line)
else:
lines.append("Delta: unavailable")
latest_full = status.get("latest_full_checkpoint") or {}
if latest_full.get("exists"):
lines.append(
f"Latest full: step {latest_full.get('step')} | {latest_full.get('checkpoint_path') or latest_full.get('raw_path')}"
)
else:
lines.append(f"Latest full: unavailable ({latest_full.get('metadata_path')})")
gpu = status.get("gpu")
if gpu:
lines.append(
f"GPU: {gpu.get('name')} | {gpu.get('utilization_gpu')}% | "
f"{gpu.get('memory_used_mib')}/{gpu.get('memory_total_mib')} MiB | "
f"{gpu.get('temperature_c')}C | {gpu.get('power_draw_w')} W"
)
warnings = status.get("warnings") or []
if warnings:
lines.append("Warnings:")
lines.extend(f"- {warning}" for warning in warnings)
return "\n".join(lines)
def _emit_status(log_path: Path, save_dir: Path, as_json: bool) -> int:
    """Collect the trainer status, print it, and return the exit code.

    Args:
        log_path: Training log passed through to ``_collect_status``.
        save_dir: Checkpoint directory passed through to ``_collect_status``.
        as_json: Print sorted, 2-space-indented JSON instead of the text
            report.

    Returns:
        The exit code reported by ``_collect_status``.
    """
    status, exit_code = _collect_status(log_path, save_dir)
    rendered = (
        json.dumps(status, indent=2, sort_keys=True)
        if as_json
        else _format_status_text(status)
    )
    print(rendered)
    return exit_code
def _run_status_command(argv: List[str]) -> int:
    """Parse the ``status`` sub-command arguments and run it.

    Args:
        argv: Arguments that follow the ``status`` word on the command line.

    Returns:
        Process exit code from ``_emit_status``.
    """
    cli = argparse.ArgumentParser(
        prog=f"{STATUS_SCRIPT_PATH.name} status",
        description="Read-only training status",
    )
    cli.add_argument("--json", dest="json_output", action="store_true", help="Emit machine-readable JSON")
    cli.add_argument("--log", type=Path, default=STATUS_DEFAULT_LOG, help="Path to the training log")
    cli.add_argument("--save_dir", type=Path, default=STATUS_DEFAULT_SAVE_DIR, help="Checkpoint directory")
    options = cli.parse_args(argv)
    return _emit_status(options.log, options.save_dir, options.json_output)
def _maybe_handle_status_fastpath() -> None:
if len(sys.argv) > 1 and sys.argv[1] == "status":
raise SystemExit(_run_status_command(sys.argv[2:]))
_maybe_handle_status_fastpath()
import torch
import torch.utils.checkpoint as torch_checkpoint
# SafeProgress - Claude-safe progress (discrete lines, not single growing line)
class SafeProgress:
    """Progress reporter that prints one complete line per report.

    A report is printed on the first ``update`` call and afterwards whenever
    ``print_every`` updates or ``print_every_sec`` seconds have passed since
    the previous report, whichever comes first.

    Args:
        total: Target count (used for the percentage and the ETA).
        initial: Count already done before this run; excluded from the rate.
        unit: Unit label shown after the counts.
        print_every: Minimum number of ``update`` calls between reports
            (clamped to at least 1).
        print_every_sec: Maximum number of seconds between reports (clamped
            to at least 1).
    """

    def __init__(self, total, initial=0, unit="tok", print_every=100, print_every_sec=60):
        self.total = total
        self.n = initial
        self.unit = unit
        self.initial = initial
        self.last_print = initial
        self.postfix = {}
        self.print_every = max(1, int(print_every))
        self.print_every_sec = max(1, int(print_every_sec))
        self.step = 0
        self.last_print_step = 0
        self.start_time = time.time()
        self.last_print_time = self.start_time

    def update(self, n=1):
        """Advance the count by ``n`` and print a report if one is due."""
        self.n += n
        self.step += 1
        now = time.time()
        is_first = self.step == 1
        steps_due = (self.step - self.last_print_step) >= self.print_every
        time_due = (now - self.last_print_time) >= self.print_every_sec
        if not (is_first or steps_due or time_due):
            return
        self._print(now)
        self.last_print = self.n
        self.last_print_step = self.step
        self.last_print_time = now

    def set_postfix(self, **kwargs):
        """Replace the ``key=value`` extras shown on each report line."""
        self.postfix = kwargs

    def _print(self, now=None):
        """Print one report line; ``now`` defaults to the current time."""
        if not now:
            now = time.time()
        elapsed = now - self.start_time
        # Rate only counts work done in this run, not the resumed offset.
        rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0
        pct = 100 * self.n / self.total if self.total > 0 else 0
        extras = ' '.join(f"{k}={v}" for k, v in self.postfix.items())
        remaining = max(0, self.total - self.n)
        eta = _status_compact_duration(remaining / rate) if rate > 0 else "unknown"
        elapsed_s = _status_compact_duration(elapsed)
        print(
            f"[{pct:.4f}%] {self.n:,}/{self.total:,} {self.unit} | "
            f"{rate:.2f} tok/s | {extras} step={self.step} eta={eta} elapsed={elapsed_s}",
            flush=True,
        )

    def close(self):
        """Print a final report line followed by ``Done.``."""
        self._print()
        print("Done.", flush=True)
import torch.nn as nn
import torch.nn.functional as F
import signal
import os
from datasets import load_dataset, DownloadConfig
from transformers import AutoTokenizer, logging as hf_log
# from tqdm.auto import tqdm # DISABLED - kills Claude context
# ─────────────────────────────── HOT DATASET LOADING ───────────────────────────────
HOT_CONFIG_PATH = Path("/workspace/hot_config.json")
# Last successfully loaded config plus the mtime of the file it came from.
_hot_config_cache = {"mtime": 0, "data": {}}


def get_hot_config() -> dict:
    """Return the contents of ``HOT_CONFIG_PATH``, re-reading it on change.

    The parsed file is cached and only re-read when its modification time
    differs from the cached one, so a file restored with an older timestamp
    is picked up as well.

    Returns:
        The parsed JSON object, or an empty dict when the file is missing,
        unreadable, malformed, or does not contain a JSON object. Failures
        are reported on stdout and never raised.
    """
    try:
        if HOT_CONFIG_PATH.exists():
            mtime = HOT_CONFIG_PATH.stat().st_mtime
            if mtime != _hot_config_cache["mtime"]:
                with open(HOT_CONFIG_PATH, encoding="utf-8") as f:
                    loaded = json.load(f)
                # Callers index the result by key; refuse lists/scalars here
                # rather than letting them fail later.
                if not isinstance(loaded, dict):
                    raise ValueError(f"expected a JSON object, got {type(loaded).__name__}")
                _hot_config_cache["data"] = loaded
                _hot_config_cache["mtime"] = mtime
            return _hot_config_cache["data"]
    except Exception as e:
        print(f"[hot_config] Error loading: {e}")
    return {}
def get_hot_datasets(default_sources: str) -> str:
    """Return the dataset list from the hot config, or ``default_sources``.

    A non-empty ``"datasets"`` entry in the hot config wins. A list value is
    joined with commas; any other value is returned unchanged. The override
    is announced on stdout.
    """
    cfg = get_hot_config()
    if "datasets" not in cfg or not cfg["datasets"]:
        return default_sources
    override = cfg["datasets"]
    if isinstance(override, list):
        override = ",".join(override)
    print(f"[hot_config] Using hot datasets: {override}")
    return override
# DISABLED: # Auto-rotating log to prevent context-window suicide
# DISABLED: try:
# DISABLED: from rotating_log import install_rotating_log
# DISABLED: install_rotating_log()
# DISABLED: except ImportError:
# pass # Running without rotation
# ───────────────────────── ASCII Sanitizer ─────────────────────────
def _ascii_safe(s):
if not isinstance(s, str):
return s
return (s
.replace('\u2019', "'").replace('\u2018', "'")
.replace('\u201C', '"').replace('\u201D', '"')
.replace('\u2014', '-').replace('\u2013', '-')
.replace('\u2026', '...')
.replace('\u00A0', ' '))
# ───────────────────────── ANSI Colors ─────────────────────────
class Colors:
    """ANSI SGR escape sequences used to colour terminal output."""
    RESET = "\033[0m"  # clear all attributes
    BOLD = "\033[1m"
    PROMPT = "\033[36m"  # cyan foreground
    GEN = "\033[0m"  # same sequence as RESET
    INFO = "\033[90m"  # bright-black (grey) foreground
    WARN = "\033[93m"  # bright-yellow foreground
# ───────────────────────── Globals ─────────────────────────
# Silence everything below error level from the transformers logger.
hf_log.set_verbosity_error()
# Single global device: CUDA when available, otherwise CPU.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    # Best effort: failures from this call are ignored.
    torch.set_float32_matmul_precision("high")
except Exception:
    pass
# Tokenizer repo id; override with the TOKENIZER_ID environment variable.
TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V4-Pro")
# AGILLM_SYNTHETIC_TOKENIZER=1/true/yes (case-insensitive) swaps in the
# built-in synthetic tokenizer.
SYNTHETIC_TOKENIZER = os.environ.get("AGILLM_SYNTHETIC_TOKENIZER", "").lower() in {"1", "true", "yes"}
class _SyntheticTokenizer:
pad_token = "<|pad|>"
pad_token_id = 0
eos_token_id = 1
sep_token_id = 1
def __init__(self, vocab_size: int):
self.vocab_size = vocab_size
self.backend_tokenizer = self
def add_special_tokens(self, _tokens):
return 0
def get_vocab(self):
return {f"tok_{i}": i for i in range(self.vocab_size)}
def encode(self, text):
return [2 + (ord(ch) % max(1, self.vocab_size - 2)) for ch in str(text)]
def decode(self, ids, skip_special_tokens=True):
return " ".join(f"tok{int(i)}" for i in ids if not skip_special_tokens or int(i) > 1)
def to_str(self):
return json.dumps({"type": "synthetic", "vocab_size": self.vocab_size})
# Build the module-global tokenizer ``tok``.
if SYNTHETIC_TOKENIZER:
    # Vocabulary size comes from AGILLM_SYNTHETIC_VOCAB (default 8192).
    tok = _SyntheticTokenizer(int(os.environ.get("AGILLM_SYNTHETIC_VOCAB", "8192")))
    print(f"[tokenizer] synthetic tokenizer enabled vocab={tok.vocab_size}")
else:
    # Prefer a local tokenizer directory (TOKENIZER_DIR); fall back to the
    # repo id when that directory does not exist.
    _tok_src = os.environ.get("TOKENIZER_DIR", "/workspace/tokenizers/deepseek-v4-pro")
    if not os.path.isdir(_tok_src):
        _tok_src = TOKENIZER_ID
    try:
        # First attempt never touches the network (local_files_only=True).
        # NOTE(review): trust_remote_code=True lets the tokenizer repository
        # run its own Python code; confirm the source is trusted.
        tok = AutoTokenizer.from_pretrained(_tok_src, use_fast=True, trust_remote_code=True, local_files_only=True)
    except Exception as _tok_exc:
        print(f"[tokenizer] offline load from {_tok_src} failed ({_tok_exc}); network fallback {TOKENIZER_ID}", flush=True)
        tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True)
# Make sure a pad token exists.
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<|pad|>"})
# ─── Fix tokenizer Ġ/▁ mismatch ───
# Some DeepSeek tokenizer releases use Ġ (U+0120) for space-prefixed tokens,
# but some transformers versions set the Metaspace pre-tokenizer to use
# ▁ (U+2581) instead, causing encode/decode to lose all spaces.
def _fix_tokenizer_space_mismatch(tokenizer):
try:
import json as _json
from tokenizers import Tokenizer as _Tokenizer
bt = tokenizer.backend_tokenizer
tj = _json.loads(bt.to_str())
pre = tj.get("pre_tokenizer", {})
needs_fix = (pre.get("type") == "Metaspace" and pre.get("replacement") == "\u2581")
if not needs_fix:
return
# Check if vocab actually uses Ġ (U+0120) for spaces
vocab = tj.get("model", {}).get("vocab", {})
has_gpt2_space = any(k.startswith("\u0120") for k in list(vocab.keys())[:500])
if not has_gpt2_space:
return
# Patch pre_tokenizer: ▁ -> Ġ
tj["pre_tokenizer"]["replacement"] = "\u0120"
# Patch decoder: ▁ -> Ġ in Replace step
for step in tj.get("decoder", {}).get("decoders", []):
if step.get("type") == "Replace":
pat = step.get("pattern", {})
if pat.get("String") == "\u2581":
pat["String"] = "\u0120"
# Rebuild backend tokenizer
fixed = _Tokenizer.from_str(_json.dumps(tj))
tokenizer.backend_tokenizer = fixed
# Verify fix
test_ids = tokenizer.encode("hello world")
test_dec = tokenizer.decode(test_ids, skip_special_tokens=True)
if "hello world" in test_dec:
print("[tokenizer] Fixed Ġ/▁ space mismatch")
else:
print(f"[tokenizer] WARNING: fix applied but decode test failed: {repr(test_dec)}")
except Exception as e:
print(f"[tokenizer] Could not fix space mismatch: {e}")
# The synthetic tokenizer has no Metaspace backend to repair.
if not SYNTHETIC_TOKENIZER:
    _fix_tokenizer_space_mismatch(tok)
# ─── Tokenizer startup health check ───
# Abort early if tokenizer can't roundtrip spaces — prevents silent data corruption
def _tokenizer_health_check(tokenizer):
    """Verify that ``tokenizer`` keeps spaces through encode/decode.

    Prints the installed ``transformers``/``tokenizers`` versions, warns on
    transformers 5.0.0 or newer (when ``packaging`` is importable), then
    round-trips a few sentences. A decoded sentence with no space at all
    terminates the process with exit status 1; a mismatch in the first three
    words only prints a warning.
    """
    import transformers as _tf
    import tokenizers as _tk

    ver = _tf.__version__
    print(f"[tokenizer] transformers={ver}, tokenizers={_tk.__version__}")
    # Warn on known-bad versions
    try:
        from packaging.version import Version
        if Version(ver) >= Version('5.0.0'):
            print(f'[tokenizer] WARNING: transformers {ver} may have Metaspace bug — verify carefully')
    except ImportError:
        pass
    # Roundtrip tests — must preserve spaces
    samples = (
        'Water boils at one hundred degrees',
        'The quick brown fox jumps over the lazy dog',
        'Hello world! This is a test sentence with spaces.',
    )
    for sample in samples:
        token_ids = tokenizer.encode(sample)
        round_trip = tokenizer.decode(token_ids, skip_special_tokens=True)
        if ' ' not in round_trip:
            print('[tokenizer] FATAL: Roundtrip lost all spaces!')
            print(f' Input: {sample!r}')
            print(f' Encoded: {token_ids[:20]}...')
            print(f' Decoded: {round_trip!r}')
            print('[tokenizer] ABORTING — fix tokenizer before training!')
            sys.exit(1)
        # Check decoded is reasonably close to input
        expected_head = sample.lower().split()[:3]
        actual_head = round_trip.lower().split()[:3]
        if expected_head != actual_head:
            print('[tokenizer] WARNING: Roundtrip diverged:')
            print(f' Input: {sample[:60]!r}')
            print(f' Decoded: {round_trip[:60]!r}')
    print('[tokenizer] Health check PASSED — spaces preserved in roundtrip')
# Skipped for the synthetic tokenizer, whose decode output is placeholder
# words rather than the original text.
if not SYNTHETIC_TOKENIZER:
    _tokenizer_health_check(tok)
# VOCAB: one more than the largest id in the tokenizer vocabulary.
# BLANK: the pad token id, falling back to 0 when it is missing or None.
# EOS:   the EOS id, or the SEP id when the tokenizer defines no EOS
#        (may be None if neither is defined).
VOCAB, BLANK, EOS = (
max(tok.get_vocab().values()) + 1,
int(getattr(tok, "pad_token_id", 0) or 0),
tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id
)
# ───────────────────────── PRESETS ─────────────────────────
# Named model sizes. Each preset gives:
#   d      - model width
#   layers - number of layers
#   heads  - attention heads (d // heads is the per-head width d_k)
#   rank   - attention rank, reported relative to d_k by print_expansion_info
# For the femto/pico/nano/micro families d_k is 16 and the "_Nx" suffix is
# rank / d_k.
PRESETS: Dict[str, Dict[str, int]] = {
"femto_1x": dict(d=16, layers=1, heads=1, rank=16),
"femto_12x": dict(d=16, layers=1, heads=1, rank=192),
"femto_24x": dict(d=16, layers=1, heads=1, rank=384),
"pico_1x": dict(d=32, layers=1, heads=2, rank=16),
"pico_3x": dict(d=32, layers=1, heads=2, rank=48),
"pico_6x": dict(d=32, layers=1, heads=2, rank=96),
"pico_12x": dict(d=32, layers=1, heads=2, rank=192),
"pico_24x": dict(d=32, layers=1, heads=2, rank=384),
"pico_48x": dict(d=32, layers=1, heads=2, rank=768),
"nano_1x": dict(d=64, layers=2, heads=4, rank=16),
"nano_3x": dict(d=64, layers=2, heads=4, rank=48),
"nano_6x": dict(d=64, layers=2, heads=4, rank=96),
"nano_12x": dict(d=64, layers=2, heads=4, rank=192),
"nano_24x": dict(d=64, layers=2, heads=4, rank=384),
"nano_48x": dict(d=64, layers=2, heads=4, rank=768),
"nano_96x": dict(d=64, layers=2, heads=4, rank=1536),
"micro_3x": dict(d=128, layers=4, heads=8, rank=48),
"micro_6x": dict(d=128, layers=4, heads=8, rank=96),
"micro_12x": dict(d=128, layers=4, heads=8, rank=192),
"micro_24x": dict(d=128, layers=4, heads=8, rank=384),
"small": dict(d=512, layers=8, heads=16, rank=64),
"smallx2": dict(d=512, layers=16, heads=16, rank=64),
"base": dict(d=768, layers=12, heads=24, rank=96),
"base18": dict(d=768, layers=18, heads=24, rank=96),
"large": dict(d=1024, layers=24, heads=16, rank=128),
# AGILLM-4 tiers. These are intentionally above the ~700M AGILLM-3 size.
# Approx dense parameter count with the current untied embedding+AR+SAT+NAT heads:
# agillm4_floor ~= 1.21B, agillm4_main ~= 1.70B, agillm4_big ~= 2.40B.
"agillm4_floor": dict(d=1280, layers=28, heads=20, rank=160),
"agillm4_main": dict(d=1536, layers=32, heads=24, rank=192),
"agillm4_big": dict(d=1792, layers=36, heads=28, rank=224),
}
# Training defaults. NOTE(review): the consumers of most of these constants
# are outside this section; the meanings noted here are read from the names
# and existing comments only.
DEFAULT_BLOCK = 1122
DEFAULT_BATCH = 4
SAT_BLOCK = 2
LR_CORE, LR_HEAD = 5e-5, 2e-4
EMIT_LAMBDA = 0.1
DEFAULT_SAVE_SEC = 24 * 3600  # 24 hours, in seconds
DEFAULT_DELTA_STEPS = 100000 # lightweight weight-only save every N steps
DEFAULT_MAX_DELTAS = 5 # keep last N deltas (older pruned after full save)
CKDIR = pathlib.Path("ckpts_expansion")
# Dataset lists are comma-separated references; each reference may carry a
# ":<config>" and/or a trailing "@<split>" (see _parse_dataset_ref).
DEFAULT_PRETRAIN_SOURCES = "LLM360/TxT360,OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1,HuggingFaceFW/fineweb,wikimedia/wikipedia:20231101.en,allenai/c4:en,EleutherAI/proof-pile-2"
DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k@train_sft"
DEFAULT_AFTER_SFT_BLOCK = 768
# Attention backend name; override with AGILLM_ATTN_BACKEND.
DEFAULT_ATTN_BACKEND = os.environ.get("AGILLM_ATTN_BACKEND", "manual")
def _env_int(name: str, default: int) -> int:
try:
return int(os.environ.get(name, default))
except (TypeError, ValueError):
return default
# Defaults that can be overridden through environment variables. Values that
# fail to parse as integers fall back to the default shown. Boolean switches
# are read as integers: any non-zero value enables them.
# --- sublinear attention ---
DEFAULT_SUBLINEAR_WINDOW = _env_int("AGILLM_SUBLINEAR_WINDOW", 256)
DEFAULT_SUBLINEAR_STRIDE = _env_int("AGILLM_SUBLINEAR_STRIDE", 64)
DEFAULT_SUBLINEAR_MAX_ANCHORS = _env_int("AGILLM_SUBLINEAR_MAX_ANCHORS", 256)
DEFAULT_SUBLINEAR_CHUNK = _env_int("AGILLM_SUBLINEAR_CHUNK", 128)
DEFAULT_SUBLINEAR_SINKS = _env_int("AGILLM_SUBLINEAR_SINKS", 4)
DEFAULT_SUBLINEAR_RECENT_ANCHORS = _env_int("AGILLM_SUBLINEAR_RECENT_ANCHORS", -1) # -1 = half of max anchors
DEFAULT_SUBLINEAR_POOLED_LANDMARKS = bool(_env_int("AGILLM_SUBLINEAR_POOLED_LANDMARKS", 0))
# --- anchor memory ---
DEFAULT_ANCHOR_MEMORY = bool(_env_int("AGILLM_ANCHOR_MEMORY", 0))
DEFAULT_ANCHOR_STRIDE = _env_int("AGILLM_ANCHOR_STRIDE", 256)
DEFAULT_ANCHOR_MAX = _env_int("AGILLM_ANCHOR_MAX", 2048)
DEFAULT_ANCHOR_POSITION = _env_int("AGILLM_ANCHOR_POSITION", -1) # -1 = stack middle
# --- KV buffer and mixture-of-experts FFN ---
DEFAULT_KV_BUFFER = bool(_env_int("AGILLM_KV_BUFFER", 0))
DEFAULT_MOE_FFN = bool(_env_int("AGILLM_MOE_FFN", 0))
DEFAULT_MOE_EXPERTS = _env_int("AGILLM_MOE_EXPERTS", 4)
DEFAULT_MOE_TOP_K = _env_int("AGILLM_MOE_TOP_K", 1)
DEFAULT_MOE_MLP_MULT = _env_int("AGILLM_MOE_MLP_MULT", 4)
# NOTE(review): presumably the target ratio of training tokens to model
# parameters; its consumer is outside this section.
AGILLM4_TOKEN_PARAM_RATIO = 100.0
# ───────────────────────── UK Time Helper ─────────────────────────
def get_uk_time() -> str:
    """Return the current UK wall-clock time as ``YYYY-MM-DD HH:MM:SS TZ``.

    ``TZ`` is ``BST`` (UTC+1) between 01:00 UTC on the last Sunday of March
    and 01:00 UTC on the last Sunday of October, and ``GMT`` otherwise.
    """
    from datetime import timedelta

    def _last_sunday_0100_utc(year: int, month: int) -> datetime:
        # Both March and October have 31 days; step back from the 31st to
        # the nearest Sunday (weekday() == 6).
        month_end = datetime(year, month, 31, 1, 0, tzinfo=timezone.utc)
        return month_end - timedelta(days=(month_end.weekday() + 1) % 7)

    utc_now = datetime.now(timezone.utc)
    bst_begins = _last_sunday_0100_utc(utc_now.year, 3)
    bst_ends = _last_sunday_0100_utc(utc_now.year, 10)
    in_summer_time = bst_begins <= utc_now < bst_ends
    uk_offset, tz_name = (1, "BST") if in_summer_time else (0, "GMT")
    uk_time = utc_now + timedelta(hours=uk_offset)
    return uk_time.strftime(f'%Y-%m-%d %H:%M:%S {tz_name}')
# ───────────────────────── Utilities ─────────────────────────
def rng_state():
    """Return the random-number-generator state for the global device.

    On CUDA the state of ``DEV`` is returned (falling back to the no-argument
    call if the device argument is rejected with ``TypeError``); otherwise
    the CPU generator state is returned.
    """
    if DEV.type != "cuda":
        return torch.get_rng_state()
    try:
        return torch.cuda.get_rng_state(DEV)
    except TypeError:
        return torch.cuda.get_rng_state()
def _is_probably_ckpt(path: pathlib.Path) -> bool:
try:
return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20)
except Exception:
return False
def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None:
    """Turn a user-supplied path into a usable checkpoint file, if possible.

    * A directory resolves to its most recently modified plausible ``*.pt``
      file (``None`` if it has none).
    * A ``*.tmp`` path resolves to the same path without ``.tmp`` when that
      file is a plausible checkpoint.
    * Any other path is returned as-is when it is a plausible checkpoint.

    When none of these applies the search continues in the parent directory.
    Any error yields ``None``.

    NOTE(review): the parent fallback recurses, so a path inside a directory
    that does not exist can resolve to a checkpoint found in a more distant
    ancestor directory. Confirm this is intended.
    """
    try:
        if path.is_dir():
            usable = [p for p in path.glob("*.pt") if _is_probably_ckpt(p)]
            if not usable:
                return None
            return max(usable, key=lambda p: p.stat().st_mtime)
        candidate = path.with_suffix("") if path.suffix == ".tmp" else path
        if _is_probably_ckpt(candidate):
            return candidate
        return _resolve_ckpt(path.parent)
    except Exception:
        return None
def _try_load(path: pathlib.Path, map_location="cpu"):
try:
return torch.load(path, map_location="cpu")
except Exception as e:
print(f"[ckpt-skip] {path} not usable: {e}")
return None
def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int):
if max_ckpts is None or max_ckpts <= 0:
return
try:
pattern = f"{phase_name}_step*.pt"
ckpts = sorted(
[p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)],
key=lambda p: p.stat().st_mtime
)
excess = len(ckpts) - max_ckpts
if excess > 0:
for p in ckpts[:excess]:
try:
p.unlink()
print(f" [prune] deleted old {p.name}")
except Exception:
pass
except Exception as e:
print(f"[ckpt-prune] error: {e}")
def print_expansion_info(cfg: dict, tie_weights: bool = False, plain: bool = False):
    """Print a summary of the attention configuration in ``cfg``.

    Args:
        cfg: Mapping with ``d``, ``heads``, ``layers`` and ``rank``.
        tie_weights: Reported as ``YES``/``NO``.
        plain: Emit ``key=value`` lines instead of the boxed layout.

    The regime label compares ``rank`` with the per-head width
    ``d // heads``: below it is COMPRESSION, equal is IDENTITY, above is
    EXPANSION.
    """
    d_k = cfg["d"] // cfg["heads"]
    rank = cfg["rank"]
    ratio = rank / d_k
    if ratio < 1:
        regime = "COMPRESSION"
    elif ratio == 1:
        regime = "IDENTITY"
    else:
        regime = "EXPANSION"
    tie_str = "YES" if tie_weights else "NO"

    if not plain:
        print("┌─────────────────────────────────────────┐")
        print("│ TUNEABLE ATTENTION CONFIG │")
        print("├─────────────────────────────────────────┤")
        print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │")
        print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │")
        print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │")
        print("└─────────────────────────────────────────┘")
        return

    print("[attention_config]")
    print(f"d_model={cfg['d']} heads={cfg['heads']} d_k={d_k}")
    print(f"layers={cfg['layers']} tie_weights={tie_str}")
    print(f"rank={rank} ratio={ratio:.1f}x regime={regime}")
# ───────────────────────── AMP helper ─────────────────────────
try:
from torch.amp import autocast as _ac, GradScaler
except ImportError:
from torch.cuda.amp import autocast as _ac, GradScaler
def _auto_amp_dtype():
    """Choose the autocast dtype for the global device.

    CUDA uses bfloat16 when the device reports support and float16 otherwise
    (also when the support query itself fails). Any other device uses
    float32.
    """
    if DEV.type != "cuda":
        return torch.float32
    try:
        bf16_ok = torch.cuda.is_bf16_supported()
    except Exception:
        return torch.float16
    if bf16_ok:
        return torch.bfloat16
    return torch.float16
def amp(enabled: bool):
    """Return a context manager for mixed-precision execution.

    Yields a CUDA autocast context (dtype from ``_auto_amp_dtype``) only
    when ``enabled`` is true and the global device is CUDA; otherwise a
    no-op context.
    """
    if enabled and DEV.type == "cuda":
        return _ac(device_type="cuda", dtype=_auto_amp_dtype())
    return nullcontext()
def _needs_grad_scaler() -> bool:
    """Report whether gradient scaling is needed.

    True only when running on CUDA with float16 autocast.
    """
    on_cuda = DEV.type == "cuda"
    return bool(on_cuda and _auto_amp_dtype() == torch.float16)
# ───────────────────────── Chat & Data Stream ─────────────────────────
def _coerce_role(r: str) -> str:
r = (r or "").lower()
if r in {"user", "human", "customer"}: return "user"
if r in {"assistant", "gpt", "bot"}: return "assistant"
if r in {"system", "context"}: return "system"
return r or "user"
def _chat_content(m: dict) -> str:
content = m.get("content", m.get("text", m.get("value", "")))
return content if isinstance(content, str) else ""
def _chat_role(m: dict) -> str:
    """Return the normalised role of a chat message dict.

    The first key present among ``role``, ``from`` and ``speaker`` supplies
    the raw label, which is then passed through ``_coerce_role``.
    """
    raw_role = ""
    for key in ("role", "from", "speaker"):
        if key in m:
            raw_role = m[key]
            break
    return _coerce_role(raw_role)
def _fallback_chat_template(messages: list[dict], add_generation_prompt: bool) -> str:
    """Render chat messages as plain ``Role: text`` lines.

    Messages whose stripped content is empty are skipped. Roles other than
    ``system`` and ``assistant`` are rendered as ``User``. With
    ``add_generation_prompt`` a trailing ``Assistant:`` line is appended
    unless the last rendered line already starts with ``Assistant:``.
    """
    prefixes = {"system": "System: ", "assistant": "Assistant: "}
    rendered = []
    for message in messages:
        role = _chat_role(message)
        body = _chat_content(message).strip()
        if body:
            rendered.append(prefixes.get(role, "User: ") + body)
    if add_generation_prompt:
        ends_with_assistant = bool(rendered) and rendered[-1].startswith("Assistant:")
        if not ends_with_assistant:
            rendered.append("Assistant:")
    return "\n".join(rendered)
def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]:
msgs = ex.get(messages_key)
if msgs is None:
for alt in ("conversations", "dialog", "turns"):
if isinstance(ex.get(alt), list):
msgs = ex[alt]; break
if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict):
norm = []
for m in msgs:
content = _chat_content(m)
if not isinstance(content, str) or not content:
continue
norm.append({"role": _chat_role(m), "content": content})
if not norm: return None
try:
return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt)
except Exception:
return _fallback_chat_template(norm, add_generation_prompt)
for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")):
if isinstance(ex.get(a), str) and isinstance(ex.get(b), str):
return f"User: {ex[a]}\nAssistant: {ex[b]}"
return None
def _parse_dataset_ref(ds_name: str):
split = "train"
ref = ds_name
if "@" in ref:
ref, split = ref.rsplit("@", 1)
split = split or "train"
if ":" in ref:
base, config = ref.split(":", 1)
else:
base, config = ref, None
return base, config, split
def _open_stream_one(ds_name: str, seed: int, streaming: bool = True):
    """Open one dataset reference and return a shuffled example iterator.

    Args:
        ds_name: Reference in ``base[:config][@split]`` form. The special
            base ``json`` treats the config part as the data file path.
        seed: Shuffle seed.
        streaming: Stream the dataset (shuffled through a 1000-example
            buffer) instead of downloading it in full before shuffling.

    Returns:
        An iterator over the shuffled dataset.
    """
    dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True)
    base, config, split = _parse_dataset_ref(ds_name)
    if not streaming:
        print(f"[download] Downloading {ds_name} (non-streaming)...")

    common = dict(split=split, streaming=streaming, download_config=dc)
    if base == "json":
        ds = load_dataset("json", data_files={"train": config}, **common)
    elif config:
        ds = load_dataset(base, config, **common)
    else:
        ds = load_dataset(base, **common)

    if streaming:
        return iter(ds.shuffle(buffer_size=1000, seed=seed))
    print(f"[download] Got {len(ds):,} examples. Shuffling...")
    return iter(ds.shuffle(seed=seed))
def token_stream(ds_names: str, target: int, seed: int = 42,
                 chat: bool = False, chat_messages_key: str = "messages",
                 sft_add_generation_prompt: bool = False, dataset_field_text: str = "text",
                 streaming: bool = True):
    """Yield token ids from the given datasets until ``target`` are emitted.

    Args:
        ds_names: Comma-separated dataset references. A ``datasets`` entry
            in the hot config, when present, replaces this value.
        target: Total number of token ids to yield.
        seed: Shuffle seed passed to each opened stream.
        chat: Try to render each example as chat text first.
        chat_messages_key: Key holding the message list in chat examples.
        sft_add_generation_prompt: Forwarded to the chat renderer.
        dataset_field_text: Field read for plain text; ``"text"`` is tried
            as a second choice.
        streaming: Forwarded to ``_open_stream_one``.

    Yields:
        Token ids, one at a time. ``EOS`` (when defined) is appended to each
        example that does not already end with it.

    Sources are consumed one at a time and cycled: an exhausted source moves
    on to the next one, wrapping around. Any other error triggers a sleep
    with capped exponential back-off and a re-open of the stream; every 5th
    consecutive failure also rotates to the next source. The generator
    returns immediately when no source is given.
    """
    ds_names = get_hot_datasets(ds_names) # HOT LOAD
    sources = [s.strip() for s in ds_names.split(",") if s.strip()]
    if not sources: return
    src_idx = 0; emitted = 0; it = None; attempts = 0; backoff_base = 2.0
    while emitted < target:
        try:
            # Streams are opened lazily and re-opened after any failure.
            if it is None: it = _open_stream_one(sources[src_idx], seed, streaming=streaming)
            ex = next(it)
            text = None
            if isinstance(ex, dict):
                if chat:
                    text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt)
                # Chat rendering failed or is disabled: use the text field.
                if text is None:
                    if dataset_field_text and isinstance(ex.get(dataset_field_text), str):
                        text = ex[dataset_field_text]
                    elif isinstance(ex.get("text"), str):
                        text = ex["text"]
            if not isinstance(text, str):
                # Unusable example: skip it without counting it as a failure.
                attempts = 0; continue
            enc = tok.encode(text)
            if EOS is not None and (len(enc) == 0 or enc[-1] != EOS):
                enc = enc + [EOS]
            for t in enc:
                yield t
                emitted += 1
                if emitted >= target: return
            attempts = 0
        except StopIteration:
            # Current source exhausted: move to the next one (wraps around).
            it = None; src_idx = (src_idx + 1) % len(sources)
        except Exception as e:
            attempts += 1
            # Back-off: 2**attempts seconds, exponent capped at 6 and the
            # sleep capped at 60 seconds.
            sleep_s = min(60.0, backoff_base ** min(attempts, 6))
            print(f"[stream-retry] {sources[src_idx]} error: {type(e).__name__}, sleeping {sleep_s:.1f}s")
            time.sleep(sleep_s); it = None
            if attempts % 5 == 0 and len(sources) > 1:
                src_idx = (src_idx + 1) % len(sources)
# ───────────────────────── ALiBi ─────────────────────────
def _alibi_slopes(n_heads: int):
    """Return per-head ALiBi slopes as a ``(1, n_heads, 1, 1)`` tensor on DEV.

    For a power-of-two head count the slopes form a geometric sequence. For
    other counts the sequence for the largest power of two below
    ``n_heads`` is extended with every second slope of the sequence for
    twice that power.
    """
    def _geometric(count):
        first = 2 ** (-2 ** -(math.log2(count) - 3))
        step = first
        return [first * (step ** idx) for idx in range(count)]

    if math.log2(n_heads).is_integer():
        slopes = _geometric(n_heads)
    else:
        nearest_pow2 = 2 ** math.floor(math.log2(n_heads))
        missing = n_heads - nearest_pow2
        slopes = _geometric(nearest_pow2) + _geometric(2 * nearest_pow2)[0::2][:missing]
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
def alibi_bias(n_heads: int, n_tokens: int):
    """Build the additive ALiBi bias of shape ``(1, n_heads, n_tokens, n_tokens)``.

    Entry ``[0, h, i, j]`` is ``-slope_h * max(j - i, 0)``.

    NOTE(review): the distance is zero for every key at or before the query
    (``j <= i``), so only keys to the right of the query are penalised.
    Trained checkpoints depend on this definition; do not change it without
    confirming intent.
    """
    positions = torch.arange(n_tokens, device=DEV)
    query_idx = positions.view(1, 1, n_tokens, 1)
    key_idx = positions.view(1, 1, 1, n_tokens)
    distance = (key_idx - query_idx).clamp_min(0)
    return -_alibi_slopes(n_heads) * distance
class StructuredAttentionMask:
    """Symbolic attention rule that can be expanded into a dense mask.

    A dense mask costs O(T^2) memory. This object only stores the rule, so
    sublinear attention can apply it to the gathered local/anchor candidate
    keys instead: O(T * candidates).

    Attributes:
        kind: Lower-cased rule name; a falsy kind becomes ``"none"``.
        q_len: Number of query positions.
        k_len: Number of key positions (defaults to ``q_len``).
        query_base: Absolute position of the first query.
        block: Block size for block-causal rules (at least 1).
    """
    __slots__ = ("kind", "q_len", "k_len", "query_base", "block")

    def __init__(self, kind: str, q_len: int, k_len: int = None, query_base: int = 0, block: int = 1):
        self.kind = (kind or "none").lower()
        self.q_len = int(q_len)
        self.k_len = int(q_len if k_len is None else k_len)
        self.query_base = int(query_base)
        self.block = max(1, int(block))

    def to_dense(self, device=None, dtype=torch.float32):
        """Materialise the rule as an additive ``(1, 1, q_len, k_len)`` mask.

        Allowed positions hold 0 and blocked positions ``-inf``. Returns
        ``None`` for the unrestricted kinds (``none``, ``nat``,
        ``bidirectional``, ``unrestricted``).

        Raises:
            ValueError: If ``kind`` is not a known rule.
        """
        device = device or DEV
        kind = self.kind
        if kind in {"none", "nat", "bidirectional", "unrestricted"}:
            return None
        first_query = self.query_base
        q_pos = torch.arange(first_query, first_query + self.q_len, device=device, dtype=torch.long).view(self.q_len, 1)
        k_pos = torch.arange(self.k_len, device=device, dtype=torch.long).view(1, self.k_len)
        if kind == "causal":
            allow = k_pos <= q_pos
        elif kind in {"sat", "block_causal", "block-causal"}:
            # A query may see every key in its own block or an earlier one.
            allow = (k_pos // self.block) <= (q_pos // self.block)
        else:
            raise ValueError(f"unknown structured attention mask kind: {self.kind}")
        open_value = torch.zeros((self.q_len, self.k_len), device=device, dtype=dtype)
        closed_value = torch.full_like(open_value, float("-inf"))
        dense = torch.where(allow, open_value, closed_value)
        return dense.unsqueeze(0).unsqueeze(0)
def _is_structured_attention_mask(mask) -> bool:
    """Return True when ``mask`` is a ``StructuredAttentionMask`` instance."""
    return isinstance(mask, StructuredAttentionMask)
def use_structured_masks(args=None, backend: str = None) -> bool:
    """Decide whether symbolic (structured) attention masks should be used.

    Args:
        args: Optional options object; ``attn_backend`` and
            ``no_structured_masks`` are read from it when present.
        backend: Backend name; takes precedence over ``args.attn_backend``.

    Returns:
        True only for the ``sublinear`` backend (case-insensitive) when
        ``args.no_structured_masks`` is not set.
    """
    chosen = (backend or getattr(args, "attn_backend", "") or "").lower()
    if chosen != "sublinear":
        return False
    opted_out = bool(getattr(args, "no_structured_masks", False))
    return not opted_out
# ───────────────────────── Model components ─────────────────────────
class KVBuffer:
    """Preallocated K/V cache for decode. Replaces torch.cat-based growth.

    Storage uses the head-major layout ``[B, H, T, d_k]`` and is allocated
    once with room for ``capacity`` positions. ``append`` copies new
    positions into the next free slots in place and advances ``length``;
    ``view`` returns slices covering only the filled positions.
    """
    __slots__ = ("k", "v", "length", "capacity")

    def __init__(
        self,
        batch: int,
        heads: int,
        capacity: int,
        d_k: int,
        device,
        dtype,
    ):
        shape = (batch, heads, capacity, d_k)
        # Uninitialised storage: only the first ``length`` positions are
        # ever exposed through view().
        self.k = torch.empty(*shape, device=device, dtype=dtype)
        self.v = torch.empty(*shape, device=device, dtype=dtype)
        self.length = 0
        self.capacity = capacity

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        """Copy new keys/values into the next free positions.

        The number of positions is taken from dimension 2 of ``k_new``.

        Raises:
            RuntimeError: If the new positions would exceed ``capacity``;
                the buffer is left unchanged.
        """
        n = k_new.size(2)
        start, end = self.length, self.length + n
        if end > self.capacity:
            raise RuntimeError(
                f"KVBuffer overflow: length={self.length} + n={n} > capacity={self.capacity}"
            )
        self.k[:, :, start:end].copy_(k_new)
        self.v[:, :, start:end].copy_(v_new)
        self.length = end

    def view(self):
        """Return ``(k, v)`` slices covering only the filled positions."""
        filled = slice(None, self.length)
        return self.k[:, :, filled], self.v[:, :, filled]
class TuneableAttentionMHA(nn.Module):
    """Multi-head attention whose scores go through a learned matrix ``U``.

    ``U`` has shape ``[d_k, r]``. When ``r <= d_k`` queries and keys are both
    projected through ``U``; when ``r > d_k`` the query is multiplied by the
    ``d_k x d_k`` metric ``U @ U.T`` and keys stay at width ``d_k``.

    ``forward`` dispatches on ``attn_backend``: "sdpa" uses
    ``F.scaled_dot_product_attention``, "sublinear" uses the local-window +
    anchor path in ``_sublinear_attention``, and anything else uses an
    explicit softmax(QK^T)V.
    """

    def __init__(
        self,
        d: int,
        h: int,
        r: int,
        use_relpos: bool = True,
        attn_backend: str = DEFAULT_ATTN_BACKEND,
        sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW,
        sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE,
        sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS,
        sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK,
        sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS,
        sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS,
        sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS,
    ):
        """Build the projections and normalize the sublinear settings.

        Args:
            d: Model width; must be divisible by ``h``.
            h: Number of heads.
            r: Column count of ``U`` (score rank).
            use_relpos: Whether a relative-position bias is added when
                ``rel_bias_tokens`` is supplied to ``forward``.
            attn_backend: Backend name, lower-cased; falsy means "manual".
            sublinear_window: Local half-window size (clamped to >= 1).
            sublinear_stride: Anchor spacing (clamped to >= 0).
            sublinear_max_anchors: Anchor budget (clamped to >= 0).
            sublinear_chunk: Query chunk size (clamped to >= 1).
            sublinear_sinks: Number of leading positions always offered as
                anchors (clamped to >= 0).
            sublinear_recent_anchors: Budget reserved for the most recent
                anchors; a negative value means half of the anchor budget.
            sublinear_pooled_landmarks: Use segment-mean keys/values for
                anchors instead of the anchor token itself.
        """
        super().__init__()
        assert d % h == 0
        self.h, self.dk, self.r = h, d // h, r
        self.use_relpos = use_relpos
        self.attn_backend = (attn_backend or "manual").lower()
        self.sublinear_window = max(1, int(sublinear_window))
        self.sublinear_stride = max(0, int(sublinear_stride))
        self.sublinear_max_anchors = max(0, int(sublinear_max_anchors))
        self.sublinear_chunk = max(1, int(sublinear_chunk))
        self.sublinear_sinks = max(0, int(sublinear_sinks))
        recent = int(sublinear_recent_anchors)
        if recent < 0:
            # Negative means "auto": reserve half of the anchor budget.
            recent = self.sublinear_max_anchors // 2
        self.sublinear_recent_anchors = min(max(0, recent), self.sublinear_max_anchors)
        self.sublinear_pooled_landmarks = bool(sublinear_pooled_landmarks)
        # Exact n1 harvest: one fused QKV projection is mathematically the same
        # as three independent bias-free Linear(d, d) projections with their
        # weights stacked along out_features.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.U = nn.Parameter(torch.randn(self.dk, r))
        nn.init.orthogonal_(self.U)
        self.proj = nn.Linear(h * self.dk, d, bias=False)
        self.drop = nn.Dropout(0.1)
        # Exact n1 harvest: for expansion ranks, (q @ U) @ (k @ U).T is
        # q @ (U @ U.T) @ k.T. This keeps score/cache width at d_k with no
        # quality change. Inference caches the metric and training recomputes
        # it so gradients through U are unchanged.
        self._metric_cache: Optional[torch.Tensor] = None
        # The fields below fingerprint U at the time the cache was filled.
        self._metric_cache_ver: int = -1
        self._metric_cache_param_id: int = -1
        self._metric_cache_data_ptr: int = -1
        self._metric_cache_shape: Tuple[int, int] = (-1, -1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Accept legacy checkpoints that stored separate q/k/v weights.

        When ``qkv.weight`` is absent but ``q.weight``, ``k.weight`` and
        ``v.weight`` are all present, they are concatenated into
        ``qkv.weight`` and removed from ``state_dict`` (mutated in place)
        before delegating to the default loader.
        """
        qkv_key = prefix + "qkv.weight"
        if qkv_key not in state_dict:
            qk = prefix + "q.weight"
            kk = prefix + "k.weight"
            vk = prefix + "v.weight"
            if qk in state_dict and kk in state_dict and vk in state_dict:
                fused = _cat_legacy_weight_blocks([state_dict[qk], state_dict[kk], state_dict[vk]])
                if fused is not None:
                    state_dict[qkv_key] = fused
                    state_dict.pop(qk)
                    state_dict.pop(kk)
                    state_dict.pop(vk)
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )

    def _proj_qk(self, x):
        """Split ``x`` ([B, N, h*d_k]) into heads and project through ``U``."""
        B, N, _ = x.shape
        return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U)

    def _reshape_v(self, x):
        """Split ``x`` ([B, N, h*d_k]) into heads: [B, h, N, d_k]."""
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2)

    def _reshape_heads(self, x):
        """Same head split as ``_reshape_v``; used for q/k on the metric path."""
        B, N, _ = x.shape
        return x.view(B, N, self.h, self.dk).transpose(1, 2)

    def _get_metric(self) -> torch.Tensor:
        """Return ``U @ U.T``.

        With autograd enabled the product is recomputed on every call.
        Otherwise a detached copy is cached and rebuilt whenever U's version
        counter, object identity, data pointer, shape, dtype or device no
        longer matches the values recorded with the cache.
        """
        if torch.is_grad_enabled():
            return self.U @ self.U.T
        cur_ver = self.U._version
        cur_param_id = id(self.U)
        cur_data_ptr = int(self.U.data_ptr())
        cur_shape = tuple(self.U.shape)
        cache = self._metric_cache
        if (
            cache is None
            or cache.dtype != self.U.dtype
            or cache.device != self.U.device
            or self._metric_cache_ver != cur_ver
            or self._metric_cache_param_id != cur_param_id
            or self._metric_cache_data_ptr != cur_data_ptr
            or self._metric_cache_shape != cur_shape
        ):
            cache = (self.U @ self.U.T).detach()
            self._metric_cache = cache
            self._metric_cache_ver = cur_ver
            self._metric_cache_param_id = cur_param_id
            self._metric_cache_data_ptr = cur_data_ptr
            self._metric_cache_shape = cur_shape
        return cache

    def train(self, mode: bool = True):
        """Switch train/eval mode; entering train mode drops the metric cache."""
        if mode:
            self._metric_cache = None
            self._metric_cache_ver = -1
            self._metric_cache_param_id = -1
            self._metric_cache_data_ptr = -1
            self._metric_cache_shape = (-1, -1)
        return super().train(mode)

    def _structured_valid(self, attn_mask, q_pos, idx):
        """Evaluate a structured mask on gathered candidate keys only.

        Args:
            attn_mask: Mask argument; anything that is not a
                ``StructuredAttentionMask`` yields None.
            q_pos: 1-D tensor of absolute query positions.
            idx: ``[len(q_pos), candidates]`` tensor of key positions.

        Returns:
            A boolean tensor shaped like ``idx`` (True = allowed), or None.

        Raises:
            ValueError: for an unknown mask kind.
        """
        if not _is_structured_attention_mask(attn_mask):
            return None
        kind = attn_mask.kind
        if kind in {"none", "nat", "bidirectional", "unrestricted"}:
            return torch.ones_like(idx, dtype=torch.bool)
        if kind == "causal":
            return idx <= q_pos[:, None]
        if kind in {"sat", "block_causal", "block-causal"}:
            block = max(1, int(attn_mask.block))
            return (idx // block) <= (q_pos[:, None] // block)
        raise ValueError(f"unknown structured attention mask kind: {kind}")

    def _sublinear_anchor_positions(self, k_len: int, device):
        """Choose the global anchor key positions for a key length.

        Anchors sit at ``stride-1, 2*stride-1, ...``. When there are more
        than ``sublinear_max_anchors`` of them, the budget is split between
        an evenly spaced sample over the whole range and the most recent
        anchors. The first ``sublinear_sinks`` positions are then merged in.
        Returns a 1-D long tensor, possibly empty.
        """
        anchor_start = self.sublinear_stride - 1
        if self.sublinear_stride <= 0 or self.sublinear_max_anchors <= 0 or anchor_start >= k_len:
            anchors = torch.empty(0, device=device, dtype=torch.long)
        else:
            all_anchors = torch.arange(anchor_start, k_len, self.sublinear_stride, device=device, dtype=torch.long)
            if all_anchors.numel() <= self.sublinear_max_anchors:
                anchors = all_anchors
            else:
                recent_budget = min(self.sublinear_recent_anchors, self.sublinear_max_anchors)
                span_budget = max(0, self.sublinear_max_anchors - recent_budget)
                parts = []
                if span_budget > 0:
                    # Evenly spaced sample over the full anchor range.
                    span_sel = torch.linspace(0, all_anchors.numel() - 1, span_budget, device=device).round().long().unique()
                    parts.append(all_anchors[span_sel])
                if recent_budget > 0:
                    parts.append(all_anchors[-recent_budget:])
                # unique() also removes anchors picked by both rules.
                anchors = torch.cat(parts).unique() if parts else torch.empty(0, device=device, dtype=torch.long)
        if self.sublinear_sinks > 0 and k_len > 0:
            sinks = torch.arange(min(self.sublinear_sinks, k_len), device=device, dtype=torch.long)
            anchors = torch.cat([sinks, anchors]).unique() if anchors.numel() else sinks
        return anchors

    def _sublinear_attention(self, q, k, v, attn_mask=None, rel_bias_tokens=None):
        """Local-window + landmark attention: O(N * (window + N/stride)).

        Queries are processed in chunks of ``sublinear_chunk``. Each query
        attends to the keys within ``sublinear_window`` positions on either
        side plus the anchors from ``_sublinear_anchor_positions``. Queries
        are assumed to be the last ``q_len`` positions of the key sequence.

        Args:
            q: ``[B, H, q_len, E]`` queries.
            k: ``[B, H, k_len, E]`` keys.
            v: ``[B, H, k_len, d_v]`` values.
            attn_mask: None, a ``StructuredAttentionMask``, or a dense
                additive mask whose last dim equals ``k_len``.
            rel_bias_tokens: When not None and ``use_relpos`` is set, a
                distance penalty built from ``_alibi_slopes`` is added.

        Returns:
            ``[B, H, q_len, d_v]`` attention output.
        """
        bsz, heads, q_len, _ = q.shape
        k_len = k.size(2)
        device = q.device
        # Absolute position of the first query within the key sequence.
        query_base = max(0, k_len - q_len)
        outputs = []
        scale = 1.0 / math.sqrt(self.dk)
        slopes = None
        if self.use_relpos and rel_bias_tokens is not None:
            # NOTE(review): the broadcast at the use site assumes the helper
            # returns slopes shaped to align with the head dim — confirm.
            slopes = _alibi_slopes(self.h).to(device=device, dtype=torch.float32)
        anchors = self._sublinear_anchor_positions(k_len, device)
        anchor_k = anchor_v = None
        if anchors.numel() and self.sublinear_pooled_landmarks and self.sublinear_stride > 1:
            # Optional pooled landmarks: each global anchor summarizes its stride segment.
            # This is off by default because it adds cumsum work; enable after benchmarking.
            ends = anchors + 1
            starts = (ends - self.sublinear_stride).clamp_min(0)
            zero_k = k.new_zeros(k.size(0), k.size(1), 1, k.size(3))
            zero_v = v.new_zeros(v.size(0), v.size(1), 1, v.size(3))
            # Prefix sums with a leading zero so segment sums are prefix[end] - prefix[start].
            prefix_k = torch.cat([zero_k, k.cumsum(dim=2)], dim=2)
            prefix_v = torch.cat([zero_v, v.cumsum(dim=2)], dim=2)
            denom = (ends - starts).to(dtype=k.dtype).view(1, 1, -1, 1).clamp_min(1)
            anchor_k = (prefix_k[:, :, ends, :] - prefix_k[:, :, starts, :]) / denom
            anchor_v = (prefix_v[:, :, ends, :] - prefix_v[:, :, starts, :]) / denom
        # Relative offsets covering the symmetric local window.
        offsets = torch.arange(
            -self.sublinear_window,
            self.sublinear_window + 1,
            device=device,
            dtype=torch.long,
        )
        for q_start in range(0, q_len, self.sublinear_chunk):
            q_end = min(q_len, q_start + self.sublinear_chunk)
            cur = q_end - q_start
            q_pos = torch.arange(query_base + q_start, query_base + q_end, device=device, dtype=torch.long)
            local_raw = q_pos[:, None] + offsets[None, :]
            local_valid = (local_raw >= 0) & (local_raw < k_len)
            # Clamp so the gather stays in range; out-of-range slots are masked via local_valid.
            local_idx = local_raw.clamp(0, max(0, k_len - 1))
            k_local = k[:, :, local_idx, :]
            v_local = v[:, :, local_idx, :]
            if anchors.numel():
                anchor_idx = anchors.view(1, -1).expand(cur, -1)
                local_lo = (q_pos - self.sublinear_window).clamp_min(0).view(-1, 1)
                local_hi = (q_pos + self.sublinear_window).clamp_max(max(0, k_len - 1)).view(-1, 1)
                # Drop anchor copies already present in the local window; duplicates bias softmax mass.
                anchor_valid = (anchor_idx < local_lo) | (anchor_idx > local_hi)
                idx = torch.cat([local_idx, anchor_idx], dim=1)
                valid = torch.cat([local_valid, anchor_valid], dim=1)
                if anchor_k is not None and anchor_v is not None:
                    # Pooled landmarks are shared by every query in the chunk.
                    k_anchor = anchor_k.unsqueeze(2).expand(-1, -1, cur, -1, -1)
                    v_anchor = anchor_v.unsqueeze(2).expand(-1, -1, cur, -1, -1)
                else:
                    k_anchor = k[:, :, anchor_idx, :]
                    v_anchor = v[:, :, anchor_idx, :]
                k_sel = torch.cat([k_local, k_anchor], dim=-2)
                v_sel = torch.cat([v_local, v_anchor], dim=-2)
            else:
                idx = local_idx
                valid = local_valid
                k_sel = k_local
                v_sel = v_local
            structured_valid = self._structured_valid(attn_mask, q_pos, idx)
            if structured_valid is not None:
                valid = valid & structured_valid
            # Per-query dot product against its own candidate set.
            scores = (q[:, :, q_start:q_end, :].unsqueeze(-2) * k_sel).sum(dim=-1) * scale
            if slopes is not None:
                dist = (q_pos.view(1, 1, cur, 1) - idx.view(1, 1, cur, -1)).abs().to(torch.float32)
                scores = scores + (-slopes * dist).to(scores.dtype)
            if torch.is_tensor(attn_mask) and attn_mask.size(-1) == k_len and attn_mask.size(-2) >= q_end:
                # Dense additive mask: pick out the entries for the gathered keys.
                mask_q = attn_mask[..., q_start:q_end, :]
                gather_idx = idx.view(1, 1, cur, -1).expand(mask_q.size(0), mask_q.size(1), cur, idx.size(1))
                scores = scores + torch.gather(mask_q, -1, gather_idx)
            scores = scores.masked_fill(~valid.view(1, 1, cur, -1), float("-inf"))
            # Softmax in float32, then cast back to the value dtype.
            weights = torch.softmax(scores.float(), dim=-1).to(v.dtype)
            outputs.append((weights.unsqueeze(-1) * v_sel).sum(dim=-2))
        return torch.cat(outputs, dim=2)

    def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False):
        """Run attention over ``x`` ([B, N, d]).

        Args:
            x: Input hidden states.
            mask: None, a dense additive mask, or a
                ``StructuredAttentionMask`` (densified for non-sublinear
                backends).
            rel_bias_tokens: Passed to ``alibi_bias`` / the sublinear path
                to build the relative-position bias; None disables it.
            kv_cache: None, a ``KVBuffer`` (appended to in place), or a
                ``(k, v)`` tuple of cached tensors.
            use_cache: When True the cache is extended and returned.

        Returns:
            ``out`` when ``use_cache`` is False, else ``(out, new_kv)``
            where ``new_kv`` is the same ``KVBuffer`` or a ``(k, v)`` tuple.
        """
        q_lin, k_lin, v_lin = self.qkv(x).chunk(3, dim=-1)
        v_new = self._reshape_v(v_lin)
        if self.r > self.dk:
            # Expansion rank: fold U @ U.T into the query, keep keys at d_k.
            q = self._reshape_heads(q_lin) @ self._get_metric()
            k_new = self._reshape_heads(k_lin)
        else:
            q = self._proj_qk(q_lin)
            k_new = self._proj_qk(k_lin)
        if kv_cache is None:
            k, v = k_new, v_new
        elif isinstance(kv_cache, KVBuffer):
            if use_cache:
                kv_cache.append(k_new, v_new)
                k, v = kv_cache.view()
            else:
                k, v = k_new, v_new
        else:
            k_cached, v_cached = kv_cache
            if use_cache:
                k = torch.cat([k_cached, k_new], dim=2)
                v = torch.cat([v_cached, v_new], dim=2)
            else:
                k, v = k_new, v_new
        attn_mask = mask
        if self.attn_backend != "sublinear" and _is_structured_attention_mask(attn_mask):
            attn_mask = attn_mask.to_dense(device=q.device, dtype=q.dtype)
        if self.attn_backend != "sublinear" and self.use_relpos and rel_bias_tokens is not None:
            # Keep only the bias rows for the current queries (the last q_len rows).
            rel = alibi_bias(self.h, rel_bias_tokens)[:, :, -q.size(2):, :].to(device=q.device, dtype=q.dtype)
            attn_mask = rel if attn_mask is None else attn_mask + rel
        if self.attn_backend == "sdpa":
            try:
                z = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=0.0,
                    scale=1.0 / math.sqrt(self.dk),
                )
            except TypeError:
                # Older torch lacks the scale kwarg. Rescale q so SDPA's default sqrt(r)
                # denominator matches the historical AGILLM sqrt(d_k) denominator.
                q_scaled = q * math.sqrt(q.size(-1) / self.dk)
                z = F.scaled_dot_product_attention(q_scaled, k, v, attn_mask=attn_mask, dropout_p=0.0)
        elif self.attn_backend == "sublinear":
            z = self._sublinear_attention(q, k, v, attn_mask=attn_mask, rel_bias_tokens=rel_bias_tokens)
        else:
            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            if attn_mask is not None:
                att = att + attn_mask
            z = att.softmax(-1) @ v
        # Merge heads back: [B, H, N, d_k] -> [B, N, H*d_k].
        z = z.transpose(1, 2).reshape(x.size(0), x.size(1), -1)
        out = self.drop(self.proj(z))
        if not use_cache:
            return out
        new_kv = kv_cache if isinstance(kv_cache, KVBuffer) else (k, v)
        return out, new_kv
class MoEFFN(nn.Module):
    """Mixture-of-experts feed-forward layer with per-token routing.

    A bias-free linear router scores ``experts`` two-layer ReLU MLPs for
    each token. With ``top_k == 1`` the selected expert's output is used
    unscaled in the forward pass while a straight-through factor carries
    gradient into the router; with larger ``top_k`` the selected experts
    are mixed by a softmax over their router scores.
    """

    def __init__(self, d: int, mlp_mult: int = 4, experts: int = 4, top_k: int = 1):
        super().__init__()
        self.d = int(d)
        self.mlp_mult = max(1, int(mlp_mult))
        self.num_experts = max(1, int(experts))
        self.top_k = min(max(1, int(top_k)), self.num_experts)
        width = self.mlp_mult * self.d
        self.router = nn.Linear(self.d, self.num_experts, bias=False)
        expert_list = []
        for _ in range(self.num_experts):
            expert_list.append(
                nn.Sequential(nn.Linear(self.d, width), nn.ReLU(), nn.Linear(width, self.d))
            )
        self.experts = nn.ModuleList(expert_list)

    def forward(self, x):
        """Route every token of ``x`` (any leading shape, last dim ``d``)."""
        in_shape = x.shape
        tokens = x.reshape(-1, in_shape[-1])
        logits = self.router(tokens.float())
        mixed = torch.zeros_like(tokens)
        if self.top_k == 1:
            probs = logits.softmax(dim=-1)
            winner = probs.argmax(dim=-1)
            for expert_id, expert in enumerate(self.experts):
                picked = winner == expert_id
                if not bool(picked.any()):
                    continue
                gate = probs[picked, expert_id].to(tokens.dtype).clamp_min(1e-6)
                # gate / gate.detach() equals 1 in the forward pass, so the
                # value is the bare expert output, yet the router still
                # receives a gradient through the numerator.
                gate_st = (gate / gate.detach()).unsqueeze(-1)
                mixed[picked] = expert(tokens[picked]) * gate_st
            return mixed.reshape(in_shape)
        top_vals, top_idx = torch.topk(logits, k=self.top_k, dim=-1)
        mix_weights = top_vals.softmax(dim=-1).to(tokens.dtype)
        for rank in range(self.top_k):
            assigned = top_idx[:, rank]
            rank_weight = mix_weights[:, rank].unsqueeze(-1)
            for expert_id, expert in enumerate(self.experts):
                rows = torch.nonzero(assigned == expert_id).flatten()
                if rows.numel() == 0:
                    continue
                contribution = expert(tokens.index_select(0, rows)) * rank_weight.index_select(0, rows)
                mixed.index_add_(0, rows, contribution)
        return mixed.reshape(in_shape)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Seed every expert from a legacy dense FFN checkpoint when present.

        Dense keys ``0.weight``/``0.bias``/``2.weight``/``2.bias`` under
        ``prefix`` are copied to each expert slot that is missing and has a
        matching shape. If anything was seeded, a missing router weight is
        filled with the current router weight and the dense keys are
        removed. ``state_dict`` is mutated in place.
        """
        dense_suffixes = ("0.weight", "0.bias", "2.weight", "2.bias")
        seeded = False
        for expert_idx, expert in enumerate(self.experts):
            expert_state = expert.state_dict()
            for suffix in dense_suffixes:
                dst_key = prefix + f"experts.{expert_idx}." + suffix
                src = state_dict.get(prefix + suffix)
                tgt = expert_state.get(suffix)
                if dst_key in state_dict:
                    continue
                if not (torch.is_tensor(src) and torch.is_tensor(tgt)):
                    continue
                if tuple(src.shape) == tuple(tgt.shape):
                    state_dict[dst_key] = src
                    seeded = True
        if seeded:
            router_key = prefix + "router.weight"
            if router_key not in state_dict:
                state_dict[router_key] = self.router.weight.detach().clone()
            for suffix in dense_suffixes:
                state_dict.pop(prefix + suffix, None)
        return super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs,
        )
class Block(nn.Module):
    """Pre-norm transformer block: attention residual, then FFN residual.

    The feed-forward part is a ``MoEFFN`` when ``moe_ffn`` is truthy and a
    dense ``Linear -> ReLU -> Linear`` stack with a 4x hidden width otherwise.
    """

    def __init__(
        self,
        d: int,
        h: int,
        r: int,
        attn_backend: str = DEFAULT_ATTN_BACKEND,
        sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW,
        sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE,
        sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS,
        sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK,
        sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS,
        sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS,
        sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS,
        moe_ffn: bool = DEFAULT_MOE_FFN,
        moe_experts: int = DEFAULT_MOE_EXPERTS,
        moe_top_k: int = DEFAULT_MOE_TOP_K,
        moe_mlp_mult: int = DEFAULT_MOE_MLP_MULT,
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        # All sublinear knobs are forwarded to the attention module untouched.
        sublinear_opts = dict(
            sublinear_window=sublinear_window,
            sublinear_stride=sublinear_stride,
            sublinear_max_anchors=sublinear_max_anchors,
            sublinear_chunk=sublinear_chunk,
            sublinear_sinks=sublinear_sinks,
            sublinear_recent_anchors=sublinear_recent_anchors,
            sublinear_pooled_landmarks=sublinear_pooled_landmarks,
        )
        self.mha = TuneableAttentionMHA(d, h, r, attn_backend=attn_backend, **sublinear_opts)
        if moe_ffn:
            self.ff = MoEFFN(d, mlp_mult=moe_mlp_mult, experts=moe_experts, top_k=moe_top_k)
        else:
            self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None):
        """Apply the block.

        Without caching the relative-position length is the input length and
        only the hidden state is returned. With ``use_cache`` the length is
        ``total_seq_len`` and ``(hidden, new_kv)`` is returned.
        """
        if not use_cache:
            seq_len = x.size(1)
            x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=seq_len)
            return x + self.ff(self.ln2(x))
        attn_out, new_kv = self.mha(
            self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True
        )
        x = x + attn_out + self.ff(self.ln2(x + attn_out))
        return x, new_kv
class Encoder(nn.Module):
    """Token embedding, a stack of ``Block`` layers and a final LayerNorm.

    MoE settings left as None are read from ``cfg`` (falling back to the
    module defaults). When ``anchor_memory`` is enabled an
    ``AnchorMemoryLayer`` is applied right after block ``anchor_position``.
    """

    def __init__(
        self,
        cfg,
        tie_weights: bool = False,
        attn_backend: str = DEFAULT_ATTN_BACKEND,
        grad_checkpoint: bool = False,
        sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW,
        sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE,
        sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS,
        sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK,
        sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS,
        sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS,
        sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS,
        anchor_memory: bool = DEFAULT_ANCHOR_MEMORY,
        anchor_stride: int = DEFAULT_ANCHOR_STRIDE,
        anchor_max: int = DEFAULT_ANCHOR_MAX,
        anchor_position: int = DEFAULT_ANCHOR_POSITION,
        moe_ffn: Optional[bool] = None,
        moe_experts: Optional[int] = None,
        moe_top_k: Optional[int] = None,
        moe_mlp_mult: Optional[int] = None,
    ):
        super().__init__()
        d, depth, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"]

        # Explicit arguments win; None means "take it from cfg or the default".
        if moe_ffn is None:
            moe_ffn = bool(cfg.get("moe_ffn", DEFAULT_MOE_FFN))
        if moe_experts is None:
            moe_experts = int(cfg.get("moe_experts", DEFAULT_MOE_EXPERTS))
        if moe_top_k is None:
            moe_top_k = int(cfg.get("moe_top_k", DEFAULT_MOE_TOP_K))
        if moe_mlp_mult is None:
            moe_mlp_mult = int(cfg.get("moe_mlp_mult", DEFAULT_MOE_MLP_MULT))
        moe_experts = max(1, int(moe_experts))
        moe_top_k = min(max(1, int(moe_top_k)), moe_experts)
        moe_mlp_mult = max(1, int(moe_mlp_mult))

        block_kwargs = dict(
            attn_backend=attn_backend,
            sublinear_window=sublinear_window,
            sublinear_stride=sublinear_stride,
            sublinear_max_anchors=sublinear_max_anchors,
            sublinear_chunk=sublinear_chunk,
            sublinear_sinks=sublinear_sinks,
            sublinear_recent_anchors=sublinear_recent_anchors,
            sublinear_pooled_landmarks=sublinear_pooled_landmarks,
            moe_ffn=bool(moe_ffn),
            moe_experts=moe_experts,
            moe_top_k=moe_top_k,
            moe_mlp_mult=moe_mlp_mult,
        )
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, r, **block_kwargs) for _ in range(depth)])
        self.ln = nn.LayerNorm(d)

        # Keep the construction settings on the instance.
        self.tie_weights = tie_weights
        self.attn_backend = attn_backend
        self.grad_checkpoint = grad_checkpoint
        self.sublinear_window = sublinear_window
        self.sublinear_stride = sublinear_stride
        self.sublinear_max_anchors = sublinear_max_anchors
        self.sublinear_chunk = sublinear_chunk
        self.sublinear_sinks = sublinear_sinks
        self.sublinear_recent_anchors = sublinear_recent_anchors
        self.sublinear_pooled_landmarks = bool(sublinear_pooled_landmarks)
        self.moe_ffn = bool(moe_ffn)
        self.moe_experts = moe_experts
        self.moe_top_k = moe_top_k
        self.moe_mlp_mult = moe_mlp_mult
        self.anchor_memory_enabled = bool(anchor_memory)
        self.anchor_stride = int(anchor_stride)
        self.anchor_max = int(anchor_max)

        n_layers = int(cfg["layers"])
        requested_position = int(anchor_position)
        # A negative position selects the middle layer; otherwise clamp to the last block.
        self.anchor_position = (
            n_layers // 2 if requested_position < 0 else min(requested_position, n_layers - 1)
        )
        if not self.anchor_memory_enabled:
            self.anchor = None
        else:
            am_cfg = AnchorMemoryConfig(
                d_model=int(cfg["d"]),
                heads=int(cfg["heads"]),
                anchor_stride=self.anchor_stride,
                max_anchors=self.anchor_max,
            )
            self.anchor = AnchorMemoryLayer(am_cfg)

    def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None):
        """Encode ``ids``.

        Returns the normalized hidden states, or ``(hidden, new_kvs)`` with
        one cache entry per block when ``use_cache`` is True. Gradient
        checkpointing is applied only on the non-cached path while training.
        """
        x = self.emb(ids)
        if use_cache:
            new_kvs = []
            for i, blk in enumerate(self.blocks):
                layer_kv = kv_caches[i] if kv_caches else None
                x, kv_out = blk(x, mask, layer_kv, use_cache=True, total_seq_len=total_seq_len)
                new_kvs.append(kv_out)
                if self.anchor is not None and i == self.anchor_position:
                    x, _ = self.anchor(x)
            return self.ln(x), new_kvs

        for i, blk in enumerate(self.blocks):
            recompute = self.grad_checkpoint and self.training
            if recompute:
                # Bind the block as a default argument so the closure keeps this layer.
                x = torch_checkpoint.checkpoint(lambda y, block=blk: block(y, mask), x, use_reentrant=False)
            else:
                x = blk(x, mask)
            if self.anchor is None or i != self.anchor_position:
                continue
            if self.grad_checkpoint and self.training:
                x, _ = torch_checkpoint.checkpoint(self.anchor, x, use_reentrant=False)
            else:
                x, _ = self.anchor(x)
        return self.ln(x)
class ARHead(nn.Module):
    """Linear vocabulary projection for the autoregressive objective.

    When ``tie_weights`` is set and an embedding weight is supplied, the
    projection has no bias and shares that weight; otherwise it is an
    independent biased ``Linear(d, VOCAB)``.
    """

    def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
        super().__init__()
        self.tie_weights = tie_weights
        share_embedding = tie_weights and embedding_weight is not None
        if not share_embedding:
            self.proj = nn.Linear(d, VOCAB)
        else:
            self.proj = nn.Linear(d, VOCAB, bias=False)
            self.proj.weight = embedding_weight

    def forward(self, h):
        """Return vocabulary logits for hidden states ``h``."""
        logits = self.proj(h)
        return logits
class NATHead(nn.Module):
    """Linear vocabulary projection for the NAT objective.

    When ``tie_weights`` is set and an embedding weight is supplied, the
    projection has no bias and shares that weight; otherwise it is an
    independent biased ``Linear(d, VOCAB)``.
    """

    def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None):
        super().__init__()
        self.tie_weights = tie_weights
        share_embedding = tie_weights and embedding_weight is not None
        if not share_embedding:
            self.proj = nn.Linear(d, VOCAB)
        else:
            self.proj = nn.Linear(d, VOCAB, bias=False)
            self.proj.weight = embedding_weight

    def forward(self, h):
        """Return vocabulary logits for hidden states ``h``."""
        logits = self.proj(h)
        return logits
class SATHead(nn.Module):
    """Vocabulary head for the SAT objective with an optional 2-way gate.

    The projection is a ``Linear -> GELU -> Linear`` stack when ``mlp`` is
    set, a bias-free layer sharing ``embedding_weight`` when tying is
    requested, and a plain biased ``Linear`` otherwise. In ``"var"`` mode a
    ``Linear(d, 2)`` gate is evaluated on position 0 of the input.
    """

    def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None, mlp: bool = False):
        super().__init__()
        self.tie_weights = tie_weights
        self.mlp = bool(mlp)
        share_embedding = tie_weights and embedding_weight is not None
        if self.mlp:
            # The MLP head is never tied, even when tie_weights is requested.
            self.proj = nn.Sequential(
                nn.Linear(d, d),
                nn.GELU(),
                nn.Linear(d, VOCAB),
            )
        elif share_embedding:
            self.proj = nn.Linear(d, VOCAB, bias=False)
            self.proj.weight = embedding_weight
        else:
            self.proj = nn.Linear(d, VOCAB)
        if mode == "var":
            self.gate = nn.Linear(d, 2)
        else:
            self.gate = None

    def forward(self, h_last):
        """Return ``(logits, gate_logits)``; ``gate_logits`` is None without a gate."""
        logits = self.proj(h_last)
        if self.gate is None:
            return logits, None
        return logits, self.gate(h_last[:, 0])
# ───────────────────────── Masks ─────────────────────────
def causal_mask(n, structured: bool = False):
    """Build a causal mask for ``n`` positions.

    Returns a ``StructuredAttentionMask`` rule when ``structured`` is set,
    otherwise a dense ``[1, 1, n, n]`` additive mask on ``DEV`` with -inf
    strictly above the diagonal.
    """
    if not structured:
        blocked = torch.full((1, 1, n, n), float("-inf"), device=DEV)
        return torch.triu(blocked, 1)
    return StructuredAttentionMask("causal", q_len=n, k_len=n, query_base=0)
def sat_mask(n, block=SAT_BLOCK, structured: bool = False):
    """Build a block-causal mask for ``n`` positions.

    A query may attend to every key whose block index (position // block)
    is not greater than its own. Returns a rule object when ``structured``
    is set, otherwise a dense ``[1, 1, n, n]`` additive mask on ``DEV``.
    """
    if structured:
        return StructuredAttentionMask("sat", q_len=n, k_len=n, query_base=0, block=block)
    positions = torch.arange(n, device=DEV)
    key_block = (positions // block).unsqueeze(0)
    query_block = key_block.T
    # Same block or an earlier block is visible.
    allow = query_block >= key_block
    dense = torch.where(allow, 0.0, float("-inf"))
    return dense.unsqueeze(0).unsqueeze(0)
def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK, structured: bool = False):
    """Build the block-causal mask for ``new_len`` queries after a cache.

    Queries occupy absolute positions ``cached_len .. cached_len+new_len-1``
    and keys cover all ``cached_len + new_len`` positions. Returns a rule
    object when ``structured`` is set, otherwise a dense
    ``[1, 1, new_len, total_len]`` additive mask on ``DEV``.
    """
    total_len = cached_len + new_len
    if structured:
        return StructuredAttentionMask("sat", q_len=new_len, k_len=total_len, query_base=cached_len, block=block)
    query_block = torch.arange(cached_len, total_len, device=DEV).unsqueeze(1) // block
    key_block = torch.arange(total_len, device=DEV).unsqueeze(0) // block
    allow = query_block >= key_block
    dense = torch.where(allow, 0.0, float("-inf"))
    return dense.unsqueeze(0).unsqueeze(0)
# ───────────────────────── Checkpoint helpers ─────────────────────────
# ───────────────────────── Delta Checkpoints (weight-only, async) ─────────────────────────
# Held by save_delta while it snapshots module weights to CPU.
_delta_lock = threading.Lock()
# Most recently started background delta writer; None until save_delta starts one.
_delta_thread: Optional[threading.Thread] = None
def _sha256_file(path: pathlib.Path) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in 1 MiB pieces so large checkpoints are never held
    in memory at once.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        while True:
            piece = handle.read(1 << 20)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
def _do_delta_save(tensors: dict, path: pathlib.Path, meta: dict):
    """Background worker: write weight-only checkpoint + checksum.

    The payload ``{"weights": tensors, **meta}`` is written to a ``.dtmp``
    file next to ``path``, hashed, and then renamed onto ``path``; the
    digest is stored in a ``.sha256`` sidecar. Failures are reported with a
    print and never raised, because this runs on a background thread.

    Args:
        tensors: Mapping of module label to state dict.
        path: Final checkpoint path.
        meta: Extra top-level entries merged into the saved payload.
    """
    tmp = None
    try:
        path.parent.mkdir(exist_ok=True, parents=True)
        tmp = path.with_suffix(path.suffix + ".dtmp")
        torch.save({"weights": tensors, **meta}, tmp, _use_new_zipfile_serialization=False)
        # Hash the temp file before the rename so the digest describes
        # exactly the bytes that end up at ``path``.
        digest = _sha256_file(tmp)
        tmp.replace(path)
        # Write sidecar checksum
        path.with_suffix(".sha256").write_text(f"{digest} {path.name}\n")
        print(f" [delta] saved {path.name} ({digest[:12]}...)")
    except Exception as e:
        print(f" [delta] FAILED {path.name}: {e}")
        # Do not leave a partial temp file behind. After a successful rename
        # the temp path no longer exists, which is fine.
        if tmp is not None:
            try:
                tmp.unlink()
            except OSError:
                pass
def _delete_delta_artifacts(path: pathlib.Path):
    """Best-effort removal of a delta file and its companion files.

    Removes ``path`` itself, its ``.sha256`` sidecar, the
    ``<suffix>.upload.sha256`` marker and a leftover ``<suffix>.dtmp`` temp
    file. Missing files and deletion errors are ignored.
    """
    suffix = path.suffix
    targets = [
        path,
        path.with_suffix(".sha256"),
        path.with_suffix(suffix + ".upload.sha256"),
        path.with_suffix(suffix + ".dtmp"),
    ]
    for target in targets:
        try:
            if not target.exists():
                continue
            target.unlink()
        except Exception:
            pass
def _unwrap_compiled_module(module: nn.Module) -> nn.Module:
    """Return ``module._orig_mod`` when that attribute exists, else ``module``."""
    try:
        return module._orig_mod
    except AttributeError:
        return module
def _checkpoint_state_dict(module: nn.Module) -> dict:
    """Return the state dict of the unwrapped module.

    Unwrapping first keeps keys free of the ``_orig_mod.`` prefix that a
    compiled wrapper would otherwise add.
    """
    unwrapped = _unwrap_compiled_module(module)
    return unwrapped.state_dict()
def _strip_orig_mod_prefix(state: dict) -> dict:
    """Drop a leading ``_orig_mod.`` from string keys.

    Non-dict input and dicts without any prefixed key are returned as the
    same object; otherwise a new dict is built in the original key order.
    """
    if not isinstance(state, dict):
        return state
    prefix = "_orig_mod."
    cleaned = {}
    found_prefixed = False
    for key, value in state.items():
        if isinstance(key, str) and key.startswith(prefix):
            key = key[len(prefix):]
            found_prefixed = True
        cleaned[key] = value
    return cleaned if found_prefixed else state
def _cat_legacy_weight_blocks(blocks: list) -> Optional[torch.Tensor]:
    """Stack legacy weight blocks along dim 0 into one fused tensor.

    Args:
        blocks: Tensors to concatenate, e.g. the q, k and v weights.

    Returns:
        The contiguous concatenation, or None when the blocks cannot be
        fused: empty input, a non-tensor entry, zero-dimensional tensors,
        or differing dtype, device, rank or trailing shape.
    """
    if not blocks or not all(torch.is_tensor(t) for t in blocks):
        return None
    first = blocks[0]
    if first.ndim == 0:
        # torch.cat rejects zero-dimensional tensors; report "not fusable"
        # instead of raising from inside a checkpoint load.
        return None
    tail_shape = tuple(first.shape[1:])
    if any(t.dtype != first.dtype or t.device != first.device for t in blocks):
        return None
    if any(t.ndim != first.ndim or tuple(t.shape[1:]) != tail_shape for t in blocks):
        return None
    return torch.cat(blocks, dim=0).contiguous()
def _fuse_qkv_in_state_dict(sd: dict) -> dict:
    """Fold legacy q/k/v.weight triples into qkv.weight before loading/filtering.

    ``sd`` is modified in place and returned. A triple is fused only when
    all three keys exist, the fused key is absent, and the blocks are
    compatible; otherwise the legacy keys are left untouched.
    """
    if not isinstance(sd, dict):
        return sd
    suffixes = (".q.weight", ".k.weight", ".v.weight")
    bases = {
        key[: -len(suffix)]
        for key in list(sd)
        if isinstance(key, str)
        for suffix in suffixes
        if key.endswith(suffix)
    }
    for base in bases:
        q_key, k_key, v_key = (base + suffix for suffix in suffixes)
        fused_key = base + ".qkv.weight"
        complete = q_key in sd and k_key in sd and v_key in sd
        if not complete or fused_key in sd:
            continue
        fused = _cat_legacy_weight_blocks([sd[q_key], sd[k_key], sd[v_key]])
        if fused is None:
            continue
        sd[fused_key] = fused
        for legacy_key in (q_key, k_key, v_key):
            sd.pop(legacy_key)
    return sd
def _expand_dense_ffn_to_moe_state_dict(sd: dict, target_sd: dict) -> dict:
    """Seed MoE expert entries from a dense FFN checkpoint.

    For every ``blocks.N.ff.experts.E.(0|2).(weight|bias)`` key in
    ``target_sd`` that is missing from ``sd``, the dense tensor
    ``blocks.N.ff.(0|2).(weight|bias)`` is reused when shapes match. For
    each layer seeded this way a missing router weight is cloned from
    ``target_sd`` and the dense keys are dropped. Works on a shallow copy;
    non-dict arguments return ``sd`` unchanged.
    """
    if not isinstance(sd, dict) or not isinstance(target_sd, dict):
        return sd
    expert_key = re.compile(r"(blocks\.\d+\.ff\.)experts\.\d+\.(0|2)\.(weight|bias)$")
    out = dict(sd)
    seeded_prefixes: set[str] = set()
    for target_key, target in target_sd.items():
        if not isinstance(target_key, str) or ".ff.experts." not in target_key:
            continue
        found = expert_key.match(target_key)
        if found is None:
            continue
        prefix, layer_slot, param_kind = found.groups()
        src = out.get(f"{prefix}{layer_slot}.{param_kind}")
        if target_key in out:
            continue
        if not (torch.is_tensor(src) and torch.is_tensor(target)):
            continue
        if tuple(src.shape) == tuple(target.shape):
            out[target_key] = src
            seeded_prefixes.add(prefix)
    for prefix in seeded_prefixes:
        router_key = prefix + "router.weight"
        router_target = target_sd.get(router_key)
        if router_key not in out and torch.is_tensor(router_target):
            out[router_key] = router_target.detach().clone()
        # The dense keys have been redistributed; drop them.
        for legacy_suffix in ("0.weight", "0.bias", "2.weight", "2.bias"):
            out.pop(prefix + legacy_suffix, None)
    return out
def _prepare_core_state_dict_for_load(core: nn.Module, sd: dict) -> dict:
    """Normalize a saved core state dict to the current layout.

    Steps, in order: strip ``_orig_mod.`` prefixes, fuse legacy q/k/v
    weights on a copy, then expand dense FFN weights into MoE expert slots
    using ``core.state_dict()`` as the target. Non-dict input is returned
    after the prefix step without further processing.
    """
    sd = _strip_orig_mod_prefix(sd)
    if not isinstance(sd, dict):
        return sd
    # Fuse on a copy so the caller's mapping is not modified.
    fused = _fuse_qkv_in_state_dict(dict(sd))
    return _expand_dense_ffn_to_moe_state_dict(fused, core.state_dict())
def _split_qkv_in_state_dict_for_test(sd: dict) -> dict:
    """Return a copy of ``sd`` with every fused qkv weight split three ways.

    Each ``<base>.qkv.weight`` is replaced by cloned ``<base>.q.weight``,
    ``<base>.k.weight`` and ``<base>.v.weight`` chunks taken along dim 0.
    The input mapping is left unchanged.
    """
    fused_suffix = ".qkv.weight"
    out = dict(sd)
    fused_keys = [key for key in out if isinstance(key, str) and key.endswith(fused_suffix)]
    for key in fused_keys:
        base = key[: -len(fused_suffix)]
        q, k, v = out.pop(key).chunk(3, dim=0)
        for part_name, part in (("q", q), ("k", k), ("v", v)):
            out[f"{base}.{part_name}.weight"] = part.clone()
    return out
def _clone_opt_value(value):
    """Return an independent copy of an optimizer state value.

    Tensors are detached and cloned; everything else is deep-copied.
    """
    if not torch.is_tensor(value):
        return copy.deepcopy(value)
    return value.detach().clone()
def _optimizer_param_name_lookup(core, ar_h, sat_h, nat_h=None) -> dict[int, str]:
    """Map ``id(param)`` to a qualified name such as ``core.emb.weight``.

    Modules are visited in the order core, ar, sat, nat; a parameter shared
    between modules keeps the first name it was seen under. Modules passed
    as None are skipped.
    """
    labelled = (("core", core), ("ar", ar_h), ("sat", sat_h), ("nat", nat_h))
    names_by_id = {}
    for label, module in labelled:
        if module is None:
            continue
        for name, param in module.named_parameters():
            key = id(param)
            if key not in names_by_id:
                names_by_id[key] = f"{label}.{name}"
    return names_by_id
def _optimizer_group_param_names(opt, core, ar_h, sat_h, nat_h=None) -> List[List[str]]:
    """List qualified parameter names for each optimizer param group.

    Parameters that belong to none of the given modules are reported as
    ``<unknown:ID>`` where ID is the parameter's ``id()``.
    """
    lookup = _optimizer_param_name_lookup(core, ar_h, sat_h, nat_h)
    names_per_group = []
    for group in opt.param_groups:
        group_names = []
        for param in group["params"]:
            group_names.append(lookup.get(id(param), f"<unknown:{id(param)}>"))
        names_per_group.append(group_names)
    return names_per_group
def _legacy_names_for_current_param(name: str) -> List[str]:
    """Return the pre-fusion parameter name(s) behind a current name.

    A fused ``<base>.qkv.weight`` maps to its three legacy q/k/v weight
    names, in that order; any other name maps to itself.
    """
    fused_suffix = ".qkv.weight"
    if not name.endswith(fused_suffix):
        return [name]
    stem = name[: -len(fused_suffix)]
    return [f"{stem}.{part}.weight" for part in ("q", "k", "v")]
def _fuse_legacy_optimizer_param_state(states: List[dict]) -> Optional[dict]:
    """Merge per-parameter optimizer states of legacy q/k/v weights.

    Only keys present in every state are kept. Tensor entries with at least
    one dim and identical shapes are concatenated along dim 0; other tensor
    entries take a clone of the first state's value; non-tensor entries
    take a deep copy of the first state's value.

    Returns:
        The fused state, or None when fewer than two states are given or
        any of them is not a dict.
    """
    if len(states) < 2 or not all(isinstance(state, dict) for state in states):
        return None
    shared_keys = set(states[0]).intersection(*states[1:])
    fused = {}
    for key in shared_keys:
        vals = [state[key] for state in states]
        head = vals[0]
        if not all(torch.is_tensor(v) for v in vals):
            fused[key] = copy.deepcopy(head)
            continue
        stackable = head.ndim > 0 and all(v.shape == head.shape for v in vals[1:])
        if stackable:
            fused[key] = torch.cat([v.detach().clone() for v in vals], dim=0).contiguous()
        else:
            # Scalars (e.g. a step counter) or mismatched shapes: keep the first.
            fused[key] = head.detach().clone()
    return fused
def _fuse_legacy_qkv_optimizer_state(opt_state: dict, opt, core, ar_h, sat_h, nat_h=None) -> Optional[dict]:
    """Remap pre-QKV-fusion AdamW state to the current fused parameter layout.

    The saved state is assumed to list parameters in the same order as the
    current optimizer, except that every fused ``qkv.weight`` was three
    separate q/k/v entries. Returns a new optimizer state dict, or None
    when the saved state lacks ``state``/``param_groups`` or its group and
    parameter counts do not match that assumption.
    """
    well_formed = isinstance(opt_state, dict) and "state" in opt_state and "param_groups" in opt_state
    if not well_formed:
        return None
    current_sd = opt.state_dict()
    current_names = _optimizer_group_param_names(opt, core, ar_h, sat_h, nat_h)

    # Expand each current name into the name(s) the legacy layout would have used.
    legacy_names = []
    for group_names in current_names:
        expanded = []
        for name in group_names:
            expanded.extend(_legacy_names_for_current_param(name))
        legacy_names.append(expanded)
    if len(legacy_names) != len(opt_state.get("param_groups", [])):
        return None

    # Pair legacy names with saved parameter ids purely by position.
    legacy_name_to_pid = {}
    for group_idx, names in enumerate(legacy_names):
        old_params = list(opt_state["param_groups"][group_idx].get("params", []))
        if len(old_params) != len(names):
            return None
        legacy_name_to_pid.update(zip(names, old_params))

    # Keep the saved hyperparameters but adopt the current parameter ids.
    new_groups = []
    for group_idx, current_group in enumerate(current_sd["param_groups"]):
        rebuilt = copy.deepcopy(opt_state["param_groups"][group_idx])
        rebuilt["params"] = list(current_group["params"])
        if "param_names" in rebuilt:
            rebuilt["param_names"] = list(current_names[group_idx])
        new_groups.append(rebuilt)

    old_states = opt_state.get("state", {})
    new_states = {}
    for group_names, current_group in zip(current_names, current_sd["param_groups"]):
        for name, new_pid in zip(group_names, current_group["params"]):
            legacy_set = _legacy_names_for_current_param(name)
            if len(legacy_set) > 1:
                old_pids = [legacy_name_to_pid.get(legacy) for legacy in legacy_set]
                fused = None
                if all(pid in old_states for pid in old_pids):
                    fused = _fuse_legacy_optimizer_param_state([old_states[pid] for pid in old_pids])
                if fused is not None:
                    new_states[new_pid] = fused
                    continue
            old_pid = legacy_name_to_pid.get(name)
            if old_pid in old_states:
                new_states[new_pid] = {
                    key: _clone_opt_value(value) for key, value in old_states[old_pid].items()
                }
    return {"state": new_states, "param_groups": new_groups}
def save_delta(core, ar_h, sat_h, nat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str):
    """Save weight-only delta in background thread. Non-blocking.

    Waits up to 60 seconds for a previous delta write, snapshots the module
    weights to CPU, and hands the snapshot to a daemon thread that writes
    ``<save_dir>/<phase_name>_delta_step<step:08d>.pt``.

    Args:
        core, ar_h, sat_h: Modules whose weights are always saved.
        nat_h: Optional module; saved under "nat" when not None.
        step: Training step, stored in the payload and used in the filename.
        seen_tok: Token counter stored in the payload.
        save_dir: Directory that receives the delta file.
        phase_name: Filename prefix.
    """
    global _delta_thread

    def _snapshot(module) -> dict:
        # copy=True forces a private copy even for tensors that already live
        # on the CPU. A plain .cpu() would return an alias there, and the
        # writer thread could then serialize weights while they are updated.
        return {
            name: tensor.detach().to("cpu", copy=True)
            for name, tensor in _checkpoint_state_dict(module).items()
        }

    # Wait for any previous delta write to finish
    if _delta_thread is not None and _delta_thread.is_alive():
        _delta_thread.join(timeout=60)
    # Snapshot weights to CPU (detach from GPU graph)
    with _delta_lock:
        tensors = {
            "core": _snapshot(core),
            "ar": _snapshot(ar_h),
            "sat": _snapshot(sat_h),
        }
        if nat_h is not None:
            tensors["nat"] = _snapshot(nat_h)
    meta = {"step": step, "seen_tok": seen_tok, "wall_time": time.time(), "delta": True}
    path = save_dir / f"{phase_name}_delta_step{step:08d}.pt"
    _delta_thread = threading.Thread(target=_do_delta_save, args=(tensors, path, meta), daemon=True)
    _delta_thread.start()
def _prune_delta_files_to_count(save_dir: pathlib.Path, phase_name: str, keep_count: int):
"""Keep only the newest keep_count complete delta files."""
try:
pattern = f"{phase_name}_delta_step*.pt"
deltas = sorted(
[p for p in save_dir.glob(pattern) if p.stat().st_size > 0],
key=lambda p: p.stat().st_mtime
)
excess = len(deltas) - max(0, keep_count)
if excess > 0:
for p in deltas[:excess]:
_delete_delta_artifacts(p)
print(f" [delta-prune] deleted {p.name}")
except Exception as e:
print(f" [delta-prune] error: {e}")
def _prune_deltas(save_dir: pathlib.Path, phase_name: str, max_deltas: int):
"""Keep only the most recent max_deltas delta files."""
if max_deltas is None or max_deltas <= 0:
return
_prune_delta_files_to_count(save_dir, phase_name, max_deltas)
def _load_module_state_compatible(module: nn.Module, state: dict, label: str = "module") -> int:
"""Load matching tensors only; skip obsolete untied vocab matrices for tied heads."""
if not isinstance(state, dict):
return 0
state = _strip_orig_mod_prefix(state)
tgt_sd = module.state_dict()
tied = bool(getattr(module, "tie_weights", False))
filt = {}
skipped = []
for k, v in state.items():
if tied and k == "proj.weight":
skipped.append(k)
continue
if k in tgt_sd and hasattr(v, "shape") and v.shape == tgt_sd[k].shape:
filt[k] = v
else:
skipped.append(k)
if filt:
module.load_state_dict(filt, strict=False)
if tied and skipped:
print(f"[ckpt] {label}: tied head active; skipped old untied tensors: {', '.join(skipped[:4])}{'...' if len(skipped)>4 else ''}")
return len(filt)
def load_delta(path: pathlib.Path, core, ar_h, sat_h, nat_h=None):
    """Restore weights from a delta checkpoint and return ``(step, seen_tok)``.

    If a ``.sha256`` sidecar exists next to ``path`` its first token must equal
    the file's digest. Raises ValueError on a checksum mismatch or when the
    file is not marked as a delta. The core state is loaded via
    ``load_state_dict`` with default strictness; heads are loaded leniently.
    """
    checksum_file = path.with_suffix(".sha256")
    if checksum_file.exists():
        expected = checksum_file.read_text().split()[0]
        actual = _sha256_file(path)
        if expected != actual:
            raise ValueError(f"Checksum mismatch for {path.name}: expected {expected[:12]}... got {actual[:12]}...")
        print(f" [delta] checksum OK for {path.name}")
    # NOTE: weights_only=False performs a full unpickle; only load trusted files.
    ck = torch.load(path, map_location="cpu", weights_only=False)
    if not ck.get("delta"):
        raise ValueError(f"{path.name} is not a delta checkpoint")
    weights = ck["weights"]
    core.load_state_dict(_prepare_core_state_dict_for_load(core, weights["core"]))
    for label, head in (("ar", ar_h), ("sat", sat_h)):
        _load_module_state_compatible(head, weights.get(label, {}), label)
    if nat_h is not None:
        nat_sd = weights.get("nat")
        if nat_sd is None:
            print("[nat] Delta has no NAT head; keeping fresh NAT initialization")
        else:
            _load_module_state_compatible(nat_h, nat_sd, "nat")
    return ck.get("step", 0), ck.get("seen_tok", 0)
def _flush_delta():
    """Block for up to 120 seconds until an in-flight delta writer thread finishes."""
    global _delta_thread
    writer = _delta_thread
    if writer is None or not writer.is_alive():
        return
    print(" [delta] flushing in-flight write...")
    writer.join(timeout=120)
def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, nat_h, opt, scaler, meta):
    """Write a full checkpoint (weights, optimizer, scaler, tokenizer, metadata) to ``path``.

    The payload is first written to ``<path>.tmp`` and then renamed over
    ``path``; afterwards ``latest.json`` in the same directory is rewritten to
    point at it. ``meta`` must contain ``"step"``; every ``meta`` entry other
    than ``cfg``/``tie_weights`` is copied into the checkpoint verbatim. The
    tokenizer is taken from the module-level ``tok``.
    """
    path.parent.mkdir(exist_ok=True, parents=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    handled_meta_keys = ("cfg", "tie_weights")
    state = {}
    for label, module in (("core", core), ("ar", ar_h), ("sat", sat_h)):
        state[label] = _checkpoint_state_dict(module)
    state["opt"] = opt.state_dict()
    state["scaler"] = scaler.state_dict()
    state["cfg"] = meta.get("cfg")
    state["tokenizer_id"] = TOKENIZER_ID
    state["tokenizer_json"] = tok.backend_tokenizer.to_str()
    state["transformers_version"] = __import__("transformers").__version__
    state["tokenizers_version"] = __import__("tokenizers").__version__
    state["tie_weights"] = meta.get("tie_weights", False)
    state.update({k: v for k, v in meta.items() if k not in handled_meta_keys})
    if nat_h is not None:
        state["nat"] = _checkpoint_state_dict(nat_h)
    torch.save(state, tmp, _use_new_zipfile_serialization=False)
    # Rename only after the write completed so `path` never holds a partial file.
    tmp.replace(path)
    pointer = {"path": str(path), "step": meta["step"]}
    (path.parent / "latest.json").write_text(json.dumps(pointer))
    print(f"\n✓ saved checkpoint {path.name}")
def load_ckpt(path, core, ar_h, sat_h, opt, scaler, nat_h=None):
    """Restore a full checkpoint and return ``(step, seen_tok, wall_time)``.

    Raises FileNotFoundError when no loadable checkpoint is found. Core weights
    are loaded via ``load_state_dict`` with default strictness, heads
    leniently. Optimizer and scaler state are best-effort: on failure a warning
    is printed and the freshly constructed state is kept (for the optimizer, a
    legacy q/k/v -> fused-qkv remap is attempted first). If the checkpoint
    embeds a tokenizer, it is restored onto the module-level ``tok``.
    """
    ckpt_path = _resolve_ckpt(path) or path
    ck = _try_load(ckpt_path, map_location="cpu")
    if ck is None:
        raise FileNotFoundError(f"No valid checkpoint at {ckpt_path}")
    core.load_state_dict(_prepare_core_state_dict_for_load(core, ck["core"]))
    for label, head in (("ar", ar_h), ("sat", sat_h)):
        _load_module_state_compatible(head, ck.get(label, {}), label)
    if nat_h is not None:
        if "nat" not in ck:
            print("[nat] Checkpoint has no NAT head; keeping fresh NAT initialization")
        else:
            _load_module_state_compatible(nat_h, ck["nat"], "nat")
    try:
        opt.load_state_dict(ck["opt"])
    except Exception as exc:
        # Direct load failed: try remapping a legacy split q/k/v optimizer layout.
        fused_opt = _fuse_legacy_qkv_optimizer_state(ck.get("opt"), opt, core, ar_h, sat_h, nat_h)
        if fused_opt is None:
            print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc})")
        else:
            try:
                opt.load_state_dict(fused_opt)
                print("[ckpt] Converted legacy q/k/v optimizer state to fused qkv layout")
            except Exception as exc2:
                print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc}; qkv remap failed: {type(exc2).__name__}: {exc2})")
    try:
        scaler.load_state_dict(ck["scaler"])
    except Exception as exc:
        print(f"[ckpt] WARNING: scaler state incompatible; resetting scaler ({type(exc).__name__}: {exc})")
    # Prefer the tokenizer that was serialized alongside the weights.
    if "tokenizer_json" in ck:
        try:
            from tokenizers import Tokenizer as _Tokenizer
            tok.backend_tokenizer = _Tokenizer.from_str(ck["tokenizer_json"])
            print("[tokenizer] Restored from checkpoint")
        except Exception as e:
            print(f"[tokenizer] WARNING: could not restore from checkpoint: {e}")
    # Surface a transformers version drift between save time and now.
    if "transformers_version" in ck:
        import transformers as _tf
        saved_version = ck["transformers_version"]
        if saved_version != _tf.__version__:
            print(f"[tokenizer] WARNING: checkpoint saved with transformers={ck['transformers_version']}, now running {_tf.__version__}")
    return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time())
def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None):
    """Leniently load matching tensors from a checkpoint into ``tgt``; return the count loaded.

    Returns 0 when the checkpoint is missing, unreadable, or yields no state
    dict. ``key`` selects a sub-dict of the checkpoint (falling back to the
    whole checkpoint when absent); a nested ``"state_dict"`` entry is unwrapped.
    Only tensors whose name and shape match ``tgt`` are loaded.
    """
    resolved = _resolve_ckpt(path) or path
    if not resolved.exists():
        return 0
    ck = _try_load(resolved, map_location="cpu")
    if ck is None:
        return 0
    sd = ck.get(key, ck) if key else ck
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]
    if isinstance(tgt, Encoder) or key == "core":
        sd = _prepare_core_state_dict_for_load(tgt, sd)
    else:
        sd = _strip_orig_mod_prefix(sd)
        if isinstance(sd, dict):
            sd = _fuse_qkv_in_state_dict(dict(sd))
    if not isinstance(sd, dict):
        return 0
    expected = tgt.state_dict()
    matching = {}
    for name, tensor in sd.items():
        if name in expected and hasattr(tensor, "shape") and tensor.shape == expected[name].shape:
            matching[name] = tensor
    if matching:
        tgt.load_state_dict(matching, strict=False)
    return len(matching)
def infer_cfg_from_ckpt(path: pathlib.Path):
    """Return a copy of the ``cfg`` dict stored in a checkpoint, or None if unavailable."""
    resolved = _resolve_ckpt(path) or path
    if not resolved.exists():
        return None
    ck = _try_load(resolved, map_location="cpu")
    if ck is None or "cfg" not in ck:
        return None
    return dict(ck["cfg"])
# ───────────────────────── Training Logic ─────────────────────────
def _load_infer_head_state(module: nn.Module, state: dict, name: str):
"""Load inference heads across small checkpoint/schema drifts.
Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT
head bias fields existed. For inference, preserve the old behavior by
explicitly zero-filling missing bias tensors, while still failing on missing
non-bias weights or shape mismatches.
"""
if not isinstance(state, dict):
module.load_state_dict(state)
return
module_state = module.state_dict()
patched = dict(state)
zero_filled = []
shape_mismatch = []
for key, target in module_state.items():
if key not in patched and key.endswith('.bias') and torch.is_tensor(target):
patched[key] = torch.zeros_like(target)
zero_filled.append(key)
for key, value in list(patched.items()):
target = module_state.get(key)
if target is None or not torch.is_tensor(value) or not torch.is_tensor(target):
continue
if tuple(value.shape) != tuple(target.shape):
shape_mismatch.append(f"{key}: ckpt={tuple(value.shape)} model={tuple(target.shape)}")
patched.pop(key)
if shape_mismatch:
raise RuntimeError(f"{name} checkpoint shape mismatch: " + "; ".join(shape_mismatch[:6]))
loaded = module.load_state_dict(patched, strict=False)
missing = [key for key in loaded.missing_keys if key not in zero_filled]
if missing:
raise RuntimeError(f"{name} checkpoint missing required keys: " + ", ".join(missing[:12]))
notes = []
if zero_filled:
notes.append("zero-filled " + ", ".join(zero_filled[:6]))
if loaded.unexpected_keys:
notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6]))
if notes:
print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True)
def _sat_head_mlp_from_state(sd: dict) -> bool:
sat_sd = sd.get("sat", {})
if sd.get("delta") and "weights" in sd:
sat_sd = sd["weights"].get("sat", sat_sd)
return any(str(key).startswith("proj.2.") for key in sat_sd)
def _parse_grow_plan(s: str) -> List[int]:
return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128]))
def _count_enabled_params(*modules) -> int:
seen_data_ptrs = set()
total = 0
for m in modules:
if m is None:
continue
for p in m.parameters():
if p.data_ptr() not in seen_data_ptrs:
seen_data_ptrs.add(p.data_ptr())
total += p.numel()
return total
def _target_token_ratio(args) -> float:
if getattr(args, "token_param_ratio", 0.0) and args.token_param_ratio > 0:
return float(args.token_param_ratio)
if str(getattr(args, "preset", "")).startswith("agillm4_"):
return AGILLM4_TOKEN_PARAM_RATIO
return 51.2 if args.chilla_max_double else 25.0
def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool):
for p in core.parameters(): p.requires_grad = not freeze_core
if freeze_core:
if unfreeze_ln:
for blk in core.blocks:
for p in blk.ln1.parameters(): p.requires_grad = True
for p in blk.ln2.parameters(): p.requires_grad = True
for p in core.ln.parameters(): p.requires_grad = True
if train_emb:
for p in core.emb.parameters(): p.requires_grad = True
def _side_update_unique_path(directory: pathlib.Path, name: str) -> pathlib.Path:
directory.mkdir(parents=True, exist_ok=True)
dest = directory / name
if not dest.exists():
return dest
stem, suffix = dest.stem, dest.suffix
stamp = time.strftime("%Y%m%d-%H%M%S", time.gmtime())
for idx in range(1000):
candidate = directory / f"{stem}.{stamp}.{idx}{suffix}"
if not candidate.exists():
return candidate
return directory / f"{stem}.{stamp}.{os.getpid()}{suffix}"
def _side_update_move(path: pathlib.Path, directory: pathlib.Path) -> pathlib.Path:
    """Move ``path`` into ``directory`` under a non-clashing name and return the new path.

    Tries ``Path.replace`` first and falls back to ``shutil.move`` when that
    raises OSError.
    """
    import shutil

    destination = _side_update_unique_path(directory, path.name)
    try:
        path.replace(destination)
    except OSError:
        shutil.move(str(path), str(destination))
    return destination
def _apply_async_side_updates(core: nn.Module, cfg: dict, args, step: int) -> list[dict]:
update_dir_s = str(getattr(args, "async_update_dir", "") or "").strip()
alpha = float(getattr(args, "async_update_alpha", 1.0) or 0.0)
if not update_dir_s or alpha <= 0.0:
return []
update_dir = pathlib.Path(update_dir_s)
if not update_dir.exists():
return []
max_updates = max(1, int(getattr(args, "async_update_max_per_check", 1) or 1))
max_age = float(getattr(args, "async_update_max_age_sec", 0.0) or 0.0)
accepted_dir = pathlib.Path(getattr(args, "async_update_accepted_dir", "") or (update_dir.parent / "accepted"))
rejected_dir = pathlib.Path(getattr(args, "async_update_rejected_dir", "") or (update_dir.parent / "rejected"))
param_map = dict(core.named_parameters())
buffer_map = dict(core.named_buffers())
now = time.time()
applied: list[dict] = []
candidates = sorted(
[p for p in update_dir.glob("*.pt") if p.is_file() and not p.name.endswith(".tmp")],
key=lambda p: p.stat().st_mtime,
)
for path in candidates[:max_updates]:
reject_reason = ""
try:
if max_age > 0 and now - path.stat().st_mtime > max_age:
reject_reason = f"stale update older than {max_age:g}s"
raise ValueError(reject_reason)
upd = torch.load(path, map_location="cpu", weights_only=False)
kind = upd.get("kind")
if kind not in {"agillm35_dblock_slice_update", "agillm4_dblock_slice_update", "agillm41_dblock_slice_update"}:
raise ValueError(f"bad update kind {kind!r}")
if dict(upd.get("cfg", {})) != dict(cfg):
raise ValueError("cfg mismatch")
block_state = upd.get("block_state")
if not isinstance(block_state, dict) or not block_state:
raise ValueError("missing block_state")
changed = 0
with torch.no_grad():
for key, value in block_state.items():
target = param_map.get(key)
if target is None:
target = buffer_map.get(key)
if target is None:
raise KeyError(f"unknown core key {key}")
if tuple(value.shape) != tuple(target.shape):
raise ValueError(f"{key} shape mismatch update={tuple(value.shape)} target={tuple(target.shape)}")
src = value.to(device=target.device, dtype=target.dtype, non_blocking=True)
if alpha >= 1.0:
target.copy_(src)
else:
target.lerp_(src, alpha)
changed += 1
del src
dest = _side_update_move(path, accepted_dir)
rec = {
"path": str(dest),
"worker_id": upd.get("worker_id"),
"block_id": upd.get("block_id"),
"layers": upd.get("layers"),
"tokens": int(upd.get("tokens") or 0),
"tok_per_sec": float(upd.get("tok_per_sec") or 0.0),
"alpha": alpha,
"keys": changed,
}
applied.append(rec)
print(json.dumps({"event": "async_side_update_applied", "step": step, **rec}), flush=True)
except Exception as exc:
try:
dest = _side_update_move(path, rejected_dir)
except Exception:
dest = path
print(
json.dumps(
{
"event": "async_side_update_rejected",
"step": step,
"path": str(dest),
"error": reject_reason or str(exc),
}
),
flush=True,
)
return applied
def _optimizer_param_groups(core, ar_h, sat_h, lr_core: float, lr_head: float, nat_h=None):
# Shared/tied vocab projections must appear in only one optimizer group.
# VRAM-first AGILLM-4 uses one embedding/projection tensor for AR/SAT/NAT.
seen: set[int] = set()
groups = []
def add(params, lr):
unique = []
for p in params:
if not p.requires_grad:
continue
key = id(p)
if key in seen:
continue
seen.add(key)
unique.append(p)
if unique:
groups.append({"params": unique, "lr": lr})
add(core.parameters(), lr_core)
add(ar_h.parameters(), lr_head)
add(sat_h.parameters(), lr_head)
if nat_h is not None:
add(nat_h.parameters(), lr_head)
return groups
def make_optimizer(args, core, ar_h, sat_h, lr_core: float, lr_head: float, nat_h=None):
    """Create the optimizer named by ``args.optimizer`` (default ``"adamw"``).

    Supported names: ``adamw`` (torch), ``adamw8bit`` and ``paged_adamw8bit``
    (bitsandbytes). Raises RuntimeError when bitsandbytes is needed but cannot
    be imported, and ValueError for any other name.
    """
    groups = _optimizer_param_groups(core, ar_h, sat_h, lr_core, lr_head, nat_h)
    opt_name = getattr(args, "optimizer", "adamw")
    if opt_name == "adamw":
        return torch.optim.AdamW(groups)
    if opt_name not in {"adamw8bit", "paged_adamw8bit"}:
        raise ValueError(f"unknown optimizer: {opt_name}")
    try:
        import bitsandbytes as bnb
    except Exception as exc:
        raise RuntimeError(
            f"--optimizer {opt_name} requires bitsandbytes. Install it in the training env first."
        ) from exc
    if opt_name == "paged_adamw8bit":
        optimizer_cls = bnb.optim.PagedAdamW8bit
    else:
        optimizer_cls = bnb.optim.AdamW8bit
    return optimizer_cls(groups)
def _nat_ids_for_training(ids: torch.Tensor, max_tokens: int) -> torch.Tensor:
if max_tokens and max_tokens > 0 and ids.size(1) > max_tokens:
return ids[:, -max_tokens:]
return ids
def _train_phase(
    args, phase_name: str,
    core, ar_h, sat_h, nat_h, opt, scaler,
    start_step, seen_tok, resume_wall_time,
    cfg, source, steps, block_size, batch_size,
    chat_cfg: dict,
    max_ckpts: int,
    target_tokens_override: Optional[int] = None,
    tie_weights: bool = False,
    streaming: bool = True
):
    """Run one training phase until its token target is reached or the stream ends.

    Tokens are pulled from ``token_stream(source, ...)``, packed into batches of
    ``batch_size`` sequences of ``block_size`` tokens, and one optimizer step is
    taken per batch. The token target is ``seen_tok + steps * block * batch``
    when ``steps`` is truthy; otherwise ``target_tokens_override`` or, failing
    that, ``_target_token_ratio(args)`` times the parameter count.

    Per step (non-dblock path): an AR loss is always back-propagated; a SAT
    loss every ``args.sat_every`` steps unless ``args.ar_only``; a mask-predict
    NAT loss every ``args.nat_every`` steps when a NAT head exists. With
    ``args.dblock`` set, the whole step is delegated to ``_dblock_step``.

    Also handled inside the loop: OOM recovery (retry, then shrink batch, then
    shrink block), async side updates, heartbeat logging, on-demand flush
    (SIGUSR1 or a ``FLUSH_NOW`` file in ``args.save_dir``), timed full
    checkpoints, step-based delta checkpoints and optional block auto-grow.

    Returns:
        ``(step, seen_tok, wall_time)`` where ``wall_time`` is ``time.time()``
        at exit, or ``resume_wall_time`` unchanged when the target was already
        reached on entry.
    """
    # BLOCK/BATCH are mutable working copies: OOM recovery and auto-grow change them.
    BLOCK = block_size
    BATCH = batch_size
    if target_tokens_override is not None:
        target_tokens = target_tokens_override
    else:
        ratio = _target_token_ratio(args)
        param_count = _count_enabled_params(core, ar_h, sat_h, nat_h)
        target_tokens = int(ratio * param_count)
        print(f"[{phase_name}] token_param_ratio={ratio:g} param_count={param_count:,} target_tokens={target_tokens:,}")
    if steps:
        # An explicit step budget is relative to the tokens already seen.
        phase_target_tokens = steps * BLOCK * BATCH
        total_tokens_needed = seen_tok + phase_target_tokens
    else:
        total_tokens_needed = target_tokens
    if total_tokens_needed <= seen_tok:
        print(f"[{phase_name}] target {total_tokens_needed} already reached.")
        return start_step, seen_tok, resume_wall_time
    stream = token_stream(
        source, total_tokens_needed, seed=42,
        chat=chat_cfg.get("chat", False),
        chat_messages_key=chat_cfg.get("key", "messages"),
        sft_add_generation_prompt=chat_cfg.get("gen_prompt", False),
        dataset_field_text=chat_cfg.get("text_field", "text"),
        streaming=streaming
    )
    ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1)
    ce_gate = nn.CrossEntropyLoss()
    ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True)
    pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok")
    grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else []
    buf: list[int] = []
    batch_accum: list[list[int]] = []
    step = start_step
    steps_since_last_grow = 0
    oom_retries = 0
    MAX_OOM_RETRIES = 2
    now_wall = time.time()
    # Back-date the "last save" monotonic timestamp by the wall-clock time that
    # has passed since `resume_wall_time` (zero when there is none).
    last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall))
    last_delta_step = start_step
    last_heartbeat_mono = time.monotonic()
    print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}")
    print(
        f"[{phase_name}] AR_ONLY={args.ar_only}, SAT_EVERY={args.sat_every}, "
        f"NAT_EVERY={args.nat_every}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}"
    )
    # One-element list so the signal handler can set the flag without `nonlocal`.
    _flush_flag = [False]
    def _on_flush_signal(signum, frame):
        _flush_flag[0] = True
        print(f"\n[{phase_name}] flush signal received; will checkpoint at next step")
    # NOTE(review): only ValueError/OSError are handled here; on a platform
    # without signal.SIGUSR1 the attribute lookup would raise AttributeError.
    try:
        signal.signal(signal.SIGUSR1, _on_flush_signal)
        print(f"[{phase_name}] on-demand flush ready: kill -USR1 {os.getpid()} or touch {pathlib.Path(args.save_dir) / 'FLUSH_NOW'}")
    except (ValueError, OSError):
        pass
    _DBS = _dblock_init(core, args) if getattr(args,'dblock',False) else None
    if DEV.type == "cuda":
        try:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            print(
                f"[vram] training-start cache cleared: "
                f"alloc={torch.cuda.memory_allocated() / (1024**3):.2f}GB "
                f"reserved={torch.cuda.memory_reserved() / (1024**3):.2f}GB "
                f"structured_masks={use_structured_masks(args)}",
                flush=True,
            )
        except Exception:
            pass
    while seen_tok < total_tokens_needed:
        # Profiling is only active in dblock mode while fewer than profile_steps batches were profiled.
        _profile_batch = _DBS is not None and int(getattr(args, "profile_steps", 0) or 0) > 0 and int(_DBS.get("profile_n", 0)) < int(getattr(args, "profile_steps", 0) or 0)
        _data_t = time.perf_counter() if _profile_batch else None
        # Fill the token buffer up to one block; stop the phase when the stream is exhausted.
        try:
            while len(buf) < BLOCK:
                buf.append(next(stream))
        except StopIteration:
            break
        if _profile_batch:
            try:
                import dblocks_train as _db_prof
                _db_prof._profile_add(_DBS, "data_stream", time.perf_counter() - _data_t)
            except Exception:
                pass
        seq = buf[:BLOCK]
        buf = buf[BLOCK:]
        batch_accum.append(seq)
        if len(batch_accum) < BATCH:
            continue
        _tensor_t = time.perf_counter() if _profile_batch else None
        ids = torch.tensor(batch_accum, device=DEV)
        if _profile_batch:
            if DEV.type == "cuda":
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass
            try:
                import dblocks_train as _db_prof
                _db_prof._profile_add(_DBS, "tensor", time.perf_counter() - _tensor_t)
            except Exception:
                pass
        batch_accum = []
        tgt_ar = ids.clone()
        try:
            if getattr(args, "dblock", False):
                # The dblock path receives opt and scaler; presumably it performs
                # its own backward/optimizer step — verify in _dblock_step.
                loss_value = _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, _DBS)
            else:
                with amp(args.amp):
                    h_ar = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)))
                    logits_ar = ar_h(h_ar)[:, :-1]
                    loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1))
                loss_value = float(loss_ar.detach().item())
                scaler.scale(loss_ar).backward()
                del h_ar, logits_ar, loss_ar
                do_sat = (not args.ar_only) and (args.sat_every <= 1 or ((step + 1) % args.sat_every == 0))
                if do_sat:
                    # Same AR+SAT objective as a summed loss, but sequential backward keeps
                    # only one core-forward activation graph live at a time on 24GB cards.
                    with amp(args.amp):
                        h_sat = core(ids, sat_mask(ids.size(1), structured=use_structured_masks(args)))
                        # NOTE(review): the head reads the last SAT_BLOCK positions while
                        # the targets are tokens 1..SAT_BLOCK — confirm this alignment is intended.
                        logits_sat, gate = sat_h(h_sat[:, -SAT_BLOCK:])
                        tgt_sat = ids[:, 1:SAT_BLOCK+1]
                        loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1))
                        if gate is not None:
                            loss_sat += EMIT_LAMBDA * ce_gate(gate, torch.ones(ids.size(0), device=DEV, dtype=torch.long))
                    loss_value += float(loss_sat.detach().item())
                    scaler.scale(loss_sat).backward()
                    del h_sat, logits_sat, loss_sat
                do_nat = (
                    nat_h is not None
                    and (not args.ar_only)
                    and args.nat_every > 0
                    and (args.nat_every <= 1 or ((step + 1) % args.nat_every == 0))
                )
                if do_nat:
                    nat_ids = _nat_ids_for_training(ids, args.nat_max_tokens)
                    with amp(args.amp):
                        # Mask-predict (CMLM) objective: corrupt a fraction of positions
                        # with BLANK and reconstruct them from surrounding context. The
                        # old CTC objective fed the clean target as input, so the head
                        # only learned to copy and collapsed at inference on all-BLANK
                        # input. This conditions on real context and cannot collapse.
                        nat_in = nat_ids.clone()
                        # Mask ratio is clamped to [0.05, 0.95].
                        ratio = min(max(float(args.nat_mask_ratio), 0.05), 0.95)
                        mask = torch.rand(nat_in.shape, device=nat_in.device) < ratio
                        # Guarantee at least one masked position so the loss is defined.
                        if not bool(mask.any()):
                            mask[..., -1] = True
                        nat_in[mask] = BLANK
                        h_nat = core(nat_in, None)
                        logits_nat = nat_h(h_nat)
                        loss_nat = F.cross_entropy(logits_nat[mask].float(), nat_ids[mask])
                        loss_nat = float(args.nat_loss_weight) * loss_nat
                    loss_value += float(loss_nat.detach().item())
                    scaler.scale(loss_nat).backward()
                    del nat_ids, nat_in, mask, h_nat, logits_nat, loss_nat
                # Gradients from all enabled objectives are accumulated; clip and step once.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_([p for group in opt.param_groups for p in group["params"]], 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
        except RuntimeError as e:
            msg = str(e).lower()
            if "out of memory" in msg or "cuda error" in msg:
                # Drop the failed batch and any partial gradients, and start with a fresh scaler.
                batch_accum = []
                opt.zero_grad(set_to_none=True)
                scaler = GradScaler(enabled=(args.amp and _needs_grad_scaler()))
                if DEV.type == "cuda":
                    torch.cuda.empty_cache()
                    torch.cuda.synchronize()
                oom_retries += 1
                if oom_retries <= MAX_OOM_RETRIES:
                    print(f"\n[{phase_name} OOM] Retry {oom_retries}/{MAX_OOM_RETRIES} at Batch={BATCH}, clearing VRAM...")
                    time.sleep(2)
                    continue
                # Retries exhausted: shrink the batch first, then the block length.
                oom_retries = 0
                if BATCH > 1:
                    print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1} (after {MAX_OOM_RETRIES} retries)")
                    BATCH -= 1
                    time.sleep(2)
                else:
                    # Shrink to ~80% of the block, rounded down to a multiple of 128, never below 128.
                    new_block = max(128, int(BLOCK * 0.8))
                    new_block = max(128, (new_block // 128) * 128)
                    if new_block >= BLOCK:
                        new_block = max(128, BLOCK - 128)
                    print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}")
                    BLOCK = new_block
                    time.sleep(2)
                steps_since_last_grow = 0
                continue
            raise
        step += 1
        # Periodic tokenizer spot-check: verify training data has spaces
        if step % 1000 == 0:
            try:
                sample_text = tok.decode(ids[0][:50].tolist(), skip_special_tokens=True)
                if len(sample_text) > 20 and " " not in sample_text:
                    print(f"\n[tokenizer] ALERT step {step}: decoded batch has NO SPACES!")
                    print(f" Sample: {repr(sample_text[:80])}")
                    print(" Check transformers version!")
            except Exception:
                pass
        oom_retries = 0
        toks_processed = BLOCK * BATCH
        seen_tok += toks_processed
        pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK)
        pbar.update(toks_processed)
        async_every = int(getattr(args, "async_update_every_steps", 0) or 0)
        if async_every > 0 and (step % async_every) == 0:
            _apply_async_side_updates(core, cfg, args, step)
        empty_cache_every = int(getattr(args, "empty_cache_every_steps", 0) or 0)
        if DEV.type == "cuda" and empty_cache_every > 0 and (step % empty_cache_every) == 0:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
        heartbeat_every = int(getattr(args, "heartbeat_every_sec", 300) or 0)
        now_mono = time.monotonic()
        if heartbeat_every > 0 and now_mono - last_heartbeat_mono >= heartbeat_every:
            mem = ""
            if DEV.type == "cuda":
                try:
                    mem = (
                        f" gpu_alloc={torch.cuda.memory_allocated() / (1024**3):.2f}GB"
                        f" gpu_reserved={torch.cuda.memory_reserved() / (1024**3):.2f}GB"
                        f" gpu_peak={torch.cuda.max_memory_allocated() / (1024**3):.2f}GB"
                    )
                except Exception:
                    mem = ""
            print(
                f"[heartbeat] phase={phase_name} pid={os.getpid()} step={step} "
                f"seen_tok={seen_tok} loss={loss_value:.3f} B={BATCH} L={BLOCK} "
                f"dblock={bool(getattr(args, 'dblock', False))} structured_masks={use_structured_masks(args)}{mem}",
                flush=True,
            )
            last_heartbeat_mono = now_mono
        # On-demand flush: triggered by the signal flag or by the FLUSH_NOW sentinel file.
        _flush_sentinel = pathlib.Path(args.save_dir) / "FLUSH_NOW"
        if _flush_flag[0] or _flush_sentinel.exists():
            _flush_flag[0] = False
            try:
                _flush_sentinel.unlink()
            except FileNotFoundError:
                pass
            _ck_name = f"{phase_name}_step{step:08d}.pt"
            _flush_delta()
            _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
            save_ckpt(pathlib.Path(args.save_dir) / _ck_name, core, ar_h, sat_h, nat_h, opt, scaler,
                      meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
            last_save_mono = time.monotonic()
            _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
            last_delta_step = step
            print(f"[{phase_name}] ON-DEMAND flush saved {_ck_name} at step {step}")
        if args.save_every_sec > 0:
            now_mono = time.monotonic()
            if now_mono - last_save_mono >= args.save_every_sec:
                ck_name = f"{phase_name}_step{step:08d}.pt"
                _flush_delta() # wait for any in-flight delta before full save
                _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts)
                save_ckpt(pathlib.Path(args.save_dir) / ck_name, core, ar_h, sat_h, nat_h, opt, scaler,
                          meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
                last_save_mono = now_mono
                # Prune old deltas after a full save (they're superseded)
                _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep)
                last_delta_step = step # reset delta counter after full save
        # ── Delta checkpoint (step-based, weight-only, async) ──
        if args.delta_every_steps > 0 and (step - last_delta_step) >= args.delta_every_steps:
            save_root = pathlib.Path(args.save_dir)
            # AGILLM4 production runs on small rented disks. When keep=1, prune
            # old deltas before the async writer creates the next multi-GB file.
            if args.delta_max_keep and args.delta_max_keep > 0:
                _flush_delta()
                _prune_delta_files_to_count(save_root, phase_name, args.delta_max_keep - 1)
            save_delta(core, ar_h, sat_h, nat_h, step, seen_tok, save_root, phase_name)
            last_delta_step = step
        if args.auto_grow:
            steps_since_last_grow += 1
            if steps_since_last_grow >= args.grow_every_steps:
                steps_since_last_grow = 0
                try:
                    # Advance to the next block size in the plan, if there is one.
                    idx = grow_plan.index(BLOCK)
                    if idx + 1 < len(grow_plan):
                        BLOCK = grow_plan[idx + 1]
                        print(f"[{phase_name} Grow] Block -> {BLOCK}")
                        if DEV.type == "cuda": torch.cuda.empty_cache()
                except ValueError:
                    # Current block is not in the plan (e.g. after an OOM shrink): insert it.
                    grow_plan = sorted(set(grow_plan + [BLOCK]))
    pbar.close()
    _flush_delta() # ensure any in-flight delta completes before final save
    if phase_name != "sft":
        save_ckpt(pathlib.Path(args.save_dir) / f"{phase_name}_final.pt", core, ar_h, sat_h, nat_h, opt, scaler,
                  meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
    else:
        print("[sft] Skipping duplicate sft_final.pt; final.pt will contain the SFT result.")
    return step, seen_tok, time.time()
# ───────────────────────── Main Orchestrator ─────────────────────────
def train(args):
    """Top-level training entry point: build the model, restore state, run the phases.

    Steps, in order: resolve the config (preset, then values inferred from an
    existing checkpoint, then CLI overrides); build the core encoder and the
    AR/SAT/(optional) NAT heads; warm-start weights; apply the freeze policy;
    create optimizer and grad scaler; resume from a delta or a full checkpoint;
    optionally seed or re-initialise the NAT head; optionally ``torch.compile``;
    run the "pretrain" phase, then an optional "sft" phase; finally write
    ``final.pt`` into ``args.save_dir``.

    Note that ``args`` is mutated in place (NAT-related options, SFT defaults).
    """
    # Legacy compatibility mode forces every NAT-related option off.
    if getattr(args, "agillm3_compat", False):
        args.no_nat_head = True
        args.nat_every = 0
        args.dblock_nat_weight = 0.0
        args.dblock_nat_prob = 0.0
        args.reinit_nat = False
        args.seed_nat_from_ar = False
        print(f"[agillm4.1] legacy compatibility mode: tokenizer={TOKENIZER_ID}, AR+SAT checkpoint schema, NAT disabled")
    cfg = PRESETS[args.preset].copy()
    tie_weights = args.tie_weights
    print_expansion_info(cfg, tie_weights)
    # Probe an existing checkpoint for its config: warm-start source first,
    # then --resume, then <save_dir>/final.pt.
    if not args.fresh:
        if args.warmstart_from:
            src_probe = pathlib.Path(args.warmstart_from)
        elif args.resume:
            src_probe = pathlib.Path(args.resume)
        else:
            src_probe = pathlib.Path(args.save_dir) / "final.pt"
        prev_cfg = infer_cfg_from_ckpt(src_probe)
    else: prev_cfg = None
    if prev_cfg:
        # Checkpoint values override the preset, but only for keys the preset defines.
        cfg.update({k: v for k, v in prev_cfg.items() if k in cfg})
        if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2)
    if args.rank: cfg["rank"] = args.rank
    if args.x2 and not prev_cfg: cfg["layers"] *= 2
    prev_moe = prev_cfg if isinstance(prev_cfg, dict) else {}
    requested_moe = bool(getattr(args, "moe_ffn", DEFAULT_MOE_FFN))
    # MoE is enabled when requested on the CLI or when the previous checkpoint
    # used it; CLI values win when MoE was requested explicitly.
    if requested_moe or bool(prev_moe.get("moe_ffn", False)):
        cfg["moe_ffn"] = True
        cfg["moe_experts"] = int(getattr(args, "moe_experts", DEFAULT_MOE_EXPERTS) if requested_moe else prev_moe.get("moe_experts", DEFAULT_MOE_EXPERTS))
        cfg["moe_top_k"] = int(getattr(args, "moe_top_k", DEFAULT_MOE_TOP_K) if requested_moe else prev_moe.get("moe_top_k", DEFAULT_MOE_TOP_K))
        cfg["moe_mlp_mult"] = int(getattr(args, "moe_mlp_mult", DEFAULT_MOE_MLP_MULT) if requested_moe else prev_moe.get("moe_mlp_mult", DEFAULT_MOE_MLP_MULT))
    else:
        cfg["moe_ffn"] = False
    use_nat_head = not bool(getattr(args, "no_nat_head", False))
    if not use_nat_head:
        cfg["nat_head"] = False
        args.nat_every = 0
        args.dblock_nat_weight = 0.0
        args.dblock_nat_prob = 0.0
    print(f"Config: {cfg}")
    print(
        "AGILLM4.1 single-file runtime: "
        f"attn_backend={args.attn_backend} grad_checkpoint={args.grad_checkpoint} "
        f"sublinear_window={args.sublinear_window} sublinear_stride={args.sublinear_stride} "
        f"sublinear_max_anchors={args.sublinear_max_anchors} sublinear_chunk={args.sublinear_chunk} "
        f"sublinear_sinks={args.sublinear_sinks} sublinear_recent_anchors={args.sublinear_recent_anchors} "
        f"sublinear_pooled_landmarks={args.sublinear_pooled_landmarks} "
        f"moe_ffn={cfg.get('moe_ffn', False)} moe_experts={cfg.get('moe_experts', 0)} "
        f"moe_top_k={cfg.get('moe_top_k', 0)} moe_mlp_mult={cfg.get('moe_mlp_mult', 0)}"
    )
    core = Encoder(
        cfg,
        tie_weights=tie_weights,
        attn_backend=args.attn_backend,
        grad_checkpoint=args.grad_checkpoint,
        sublinear_window=args.sublinear_window,
        sublinear_stride=args.sublinear_stride,
        sublinear_max_anchors=args.sublinear_max_anchors,
        sublinear_chunk=args.sublinear_chunk,
        sublinear_sinks=args.sublinear_sinks,
        sublinear_recent_anchors=args.sublinear_recent_anchors,
        sublinear_pooled_landmarks=args.sublinear_pooled_landmarks,
        anchor_memory=getattr(args, "anchor_memory", DEFAULT_ANCHOR_MEMORY),
        anchor_stride=getattr(args, "anchor_stride", DEFAULT_ANCHOR_STRIDE),
        anchor_max=getattr(args, "anchor_max", DEFAULT_ANCHOR_MAX),
        anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
    ).to(DEV)
    # With tie_weights, every head receives the core embedding weight.
    ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
    sat_h = SATHead(cfg["d"], mode="var", tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
    nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if use_nat_head else None
    total_params = _count_enabled_params(core, ar_h, sat_h, nat_h)
    print(f"Total parameters: {total_params:,}")
    if tie_weights:
        head_names = "AR/SAT/NAT" if nat_h is not None else "AR/SAT"
        print(f"{Colors.WARN}[weight-tying] Embedding and {head_names} vocab projections share one tensor (VRAM-first){Colors.RESET}")
    # Lenient warm start: only tensors with matching name and shape are loaded.
    if not args.fresh:
        src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt"
        src = _resolve_ckpt(src)
        if src:
            loaded = _safe_load_any(src, core, key="core")
            _safe_load_any(src, ar_h, key="ar")
            _safe_load_any(src, sat_h, key="sat")
            nat_loaded = _safe_load_any(src, nat_h, key="nat") if nat_h is not None else 0
            if nat_h is not None and not nat_loaded:
                print("[nat] Warm-start source has no NAT head; NAT head initialized fresh")
            if loaded: print(f"Warm-start loaded from {src}")
    # Freeze policy must be applied before the optimizer is built, because
    # parameters with requires_grad=False are left out of the param groups.
    _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb)
    opt = make_optimizer(args, core, ar_h, sat_h, args.lr_core, args.lr_head, nat_h)
    scaler = GradScaler(enabled=(args.amp and _needs_grad_scaler()))
    start_step, seen_tok, last_wall = 0, 0, None
    # A delta resume restores weights only; a full resume also restores optimizer and scaler.
    if args.resume_delta and not args.fresh:
        delta_step, delta_tok = load_delta(pathlib.Path(args.resume_delta), core, ar_h, sat_h, nat_h)
        start_step, seen_tok, last_wall = delta_step, delta_tok, None
        print(f"Resumed from DELTA at step {start_step} (optimizer state reset — momentum rebuilds in ~100 steps)")
    elif args.resume and not args.fresh:
        start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler, nat_h)
        print(f"Resumed from step {start_step}")
    if getattr(args, "seed_nat_from_ar", False) and nat_h is not None and ar_h is not None:
        # Seed the non-autoregressive (NAT) head from the trained AR head ("father").
        # Same hidden->vocab projection shape, so NAT starts knowing the token
        # distribution instead of from random/blank -> faster, no collapse.
        with torch.no_grad():
            nat_h.proj.weight.copy_(ar_h.proj.weight)
            if nat_h.proj.bias is not None:
                if getattr(ar_h.proj, "bias", None) is not None:
                    nat_h.proj.bias.copy_(ar_h.proj.bias)
                else:
                    nat_h.proj.bias.zero_()
        print("[nat] Seeded NAT head from the AR head ('father') for the mask-predict objective")
    elif getattr(args, "reinit_nat", False) and nat_h is not None:
        for _m in nat_h.modules():
            if isinstance(_m, nn.Linear):
                nn.init.normal_(_m.weight, mean=0.0, std=0.02)
                if _m.bias is not None:
                    nn.init.zeros_(_m.bias)
        print("[nat] Reinitialized NAT head weights (random) for the mask-predict objective")
    # torch.compile AFTER loading checkpoint (key names differ)
    if args.compile:
        print("[torch.compile] Compiling model...")
        core = torch.compile(core, mode="reduce-overhead")
        ar_h = torch.compile(ar_h, mode="reduce-overhead")
        sat_h = torch.compile(sat_h, mode="reduce-overhead")
        if nat_h is not None:
            nat_h = torch.compile(nat_h, mode="reduce-overhead")
        print("[torch.compile] Done.")
    step, seen_tok, last_wall = _train_phase(
        args, "pretrain", core, ar_h, sat_h, nat_h, opt, scaler,
        start_step, seen_tok, last_wall, cfg,
        args.source, args.steps,
        args.block or DEFAULT_BLOCK,
        args.batch_size or DEFAULT_BATCH,
        chat_cfg={"chat": args.chat, "key": args.chat_messages_key, "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text},
        max_ckpts=args.max_ckpts,
        target_tokens_override=args.target_tokens,
        tie_weights=tie_weights
    )
    # SFT steps requested without a source: fall back to the default SFT sources in chat mode.
    if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0):
        args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES
        args.after_sft_chat = True
        if args.after_sft_add_generation_prompt is None: args.after_sft_add_generation_prompt = True
        if not args.after_sft_block: args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK
    if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0:
        print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...")
        _phase_freeze(core,
                      freeze_core=args.after_sft_freeze_core,
                      unfreeze_ln=args.after_sft_unfreeze_ln,
                      train_emb=args.after_sft_train_emb)
        # A new optimizer is built for SFT (new freeze policy / learning rates);
        # the grad scaler from pretraining is reused.
        opt = make_optimizer(
            args,
            core,
            ar_h,
            sat_h,
            args.after_sft_lr_core or args.lr_core,
            args.after_sft_lr_head or args.lr_head,
            nat_h,
        )
        step, seen_tok, last_wall = _train_phase(
            args, "sft", core, ar_h, sat_h, nat_h, opt, scaler,
            step, seen_tok, last_wall, cfg,
            args.after_sft_source, args.after_sft_steps,
            args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK,
            args.batch_size or DEFAULT_BATCH,
            chat_cfg={
                "chat": args.after_sft_chat,
                "key": args.after_sft_chat_messages_key,
                "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt,
                "text_field": args.after_sft_dataset_field_text
            },
            max_ckpts=args.max_ckpts,
            target_tokens_override=None,
            tie_weights=tie_weights,
            streaming=True
        )
    save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, nat_h, opt, scaler,
              meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights})
    print("🎉 All Training Complete")
# ───────────────────────── Sampling ─────────────────────────
def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p):
    """Penalise already-generated tokens in ``logits`` (modified in place).

    The history is taken from row 0 of ``ids``: its last ``n`` entries when
    ``n > 0``, otherwise the whole row. For every distinct token in that
    history:

    * ``pres_p + freq_p * count`` is subtracted (skipped when both are falsy);
    * when ``rep_p != 1.0`` positive logits are divided by ``rep_p`` and
      non-positive logits are multiplied by it.

    Returns the same ``logits`` tensor; an empty ``ids`` is a no-op.
    """
    if not ids.numel():
        return logits

    if n > 0:
        window = ids[0, -n:]
    else:
        window = ids[0]
    seen, occurrences = torch.unique(window.long(), return_counts=True)

    # Additive presence + frequency penalty on every token seen in the window.
    if pres_p or freq_p:
        logits[..., seen] -= (pres_p + freq_p * occurrences.float())

    # Multiplicative repetition penalty, sign-aware so it always lowers the score.
    if rep_p != 1.0:
        current = logits[..., seen]
        logits[..., seen] = torch.where(current > 0, current / rep_p, current * rep_p)

    return logits
def _sample(logits, T, top_k, top_p, min_p, greedy):
    """Pick one token id per row of ``logits``.

    Args:
        logits: scores with the vocabulary on the last axis.
        T: temperature; values below ``1e-8`` are clamped to ``1e-8``.
        top_k: keep only the ``top_k`` most probable tokens (falsy disables).
        top_p: keep the sorted prefix whose cumulative probability is
            ``<= top_p`` (``>= 1.0`` disables). The token that crosses the
            threshold is *not* kept.
        min_p: zero every probability below this absolute value (``<= 0``
            disables).
        greedy: return the argmax and skip all sampling.

    Returns:
        Index tensor with a trailing dimension of size 1.

    Rows whose probability mass was filtered away completely fall back to the
    argmax of the raw logits. The fallback and the normalisation are applied
    per row, so one fully-filtered row in a batch no longer makes
    ``multinomial`` fail for the whole batch.
    """
    if greedy:
        return logits.argmax(-1, keepdim=True)

    probs = (logits / max(T, 1e-8)).softmax(-1)

    if top_k:
        v, i = torch.topk(probs, min(top_k, probs.size(-1)))
        probs = torch.zeros_like(probs).scatter_(-1, i, v)

    if top_p < 1.0:
        s_probs, s_idx = torch.sort(probs, descending=True, dim=-1)
        # Cast the mask to the probability dtype: scatter_ requires src and
        # destination dtypes to match, which `.float()` breaks for half inputs.
        keep = (torch.cumsum(s_probs, -1) <= top_p).to(s_probs.dtype)
        probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * keep)

    if min_p > 0:
        probs[probs < min_p] = 0

    mass = probs.sum(-1, keepdim=True)
    dead = mass == 0
    if bool(dead.all()):
        # Nothing survived filtering anywhere: deterministic argmax.
        return logits.argmax(-1, keepdim=True)
    if bool(dead.any()):
        # Give each emptied row a one-hot distribution on its argmax so that
        # multinomial returns exactly that token for the row.
        one_hot = torch.zeros_like(probs).scatter_(-1, logits.argmax(-1, keepdim=True), 1.0)
        probs = torch.where(dead, one_hot, probs)
        mass = probs.sum(-1, keepdim=True)
    return probs.div_(mass).multinomial(1)
def _dblock_block_layers(core, dblock_blocks):
    """Partition the layer indices of ``core.blocks`` into contiguous groups.

    Args:
        core: object exposing a sized ``blocks`` sequence.
        dblock_blocks: requested number of groups (values below 1 become 1).

    Returns:
        A list with exactly ``max(1, int(dblock_blocks))`` lists of layer
        indices. Every group but the last holds ``len(core.blocks) // groups``
        layers (at least one); the last group takes the remainder.

    When more groups are requested than there are layers, the surplus groups
    are empty rather than pointing at layer indices that do not exist. The
    group count is preserved so that indices derived from the sigma
    boundaries stay valid.
    """
    n_layers = len(core.blocks)
    n_groups = max(1, int(dblock_blocks))
    per = max(1, n_layers // n_groups)
    groups = []
    for g in range(n_groups):
        # Clamp both bounds so no index >= n_layers is ever emitted.
        lo = min(g * per, n_layers)
        hi = n_layers if g == n_groups - 1 else min((g + 1) * per, n_layers)
        groups.append(list(range(lo, hi)))
    return groups
def _dblock_select_block(sigma, bsig):
    """Return the index of the sigma interval that contains ``sigma``.

    ``bsig`` holds interval boundaries; interval ``b`` is the closed range
    ``[bsig[b], bsig[b + 1]]``. The first matching interval wins, so a value
    sitting exactly on an inner boundary maps to the lower interval. Values
    below ``bsig[0]`` map to 0; anything else that matches no interval maps to
    the last interval index (``len(bsig) - 2``).
    """
    last = len(bsig) - 1
    idx = 0
    while idx < last:
        if bsig[idx] <= sigma and sigma <= bsig[idx + 1]:
            return idx
        idx += 1
    if sigma < bsig[0]:
        return 0
    return last - 1
def _edm_denoise_block(core, layers, z, sigma_t, mask, args):
    """Run one group of layers as a preconditioned denoiser.

    ``_edm_pre(sigma_t)`` supplies three coefficients. The input ``z`` is
    scaled by the third, pushed through ``core.blocks[i]`` for each ``i`` in
    ``layers`` (in order) via ``_run_block``, and the result is combined as
    ``first * z + second * hidden``.
    """
    skip_scale, out_scale, in_scale = _edm_pre(sigma_t)
    hidden = in_scale * z
    for layer_idx in layers:
        # NOTE(review): the positional False is forwarded to _run_block as-is;
        # its meaning is defined by that helper — confirm there.
        hidden = _run_block(core.blocks[layer_idx], hidden, mask, False, args)
    return skip_scale * z + out_scale * hidden
@torch.no_grad()
def _dblock_euler_hidden(core, ids, args):
    """DiffusionBlocks EDM Euler block-chain hidden state (faithful reverse ODE),
    adapted to agillm4.1's causal AR head. --euler_start_sigma tunes context
    conditioning (SDEdit-style); returns LayerNorm'd hidden [B,T,d].

    Reads from ``args`` (all via ``getattr`` with fallbacks): ``dblock_blocks``
    (default 4), ``euler_steps`` (default ``2 * dblock_blocks``, never fewer
    than ``dblock_blocks``), ``euler_start_sigma`` (``<= 0`` means "use the
    last block boundary") and ``amp`` (default False).

    The noise schedule is log-linear from the start sigma down to the first
    block boundary. Each Euler step denoises with the layer group whose sigma
    interval contains the current sigma; a final denoise at the minimum sigma
    with group 0 produces the tensor passed through ``core.ln``.
    """
    import numpy as _np
    dblock_blocks = int(getattr(args, "dblock_blocks", 4) or 4)
    steps = max(dblock_blocks, int(getattr(args, "euler_steps", 0) or (dblock_blocks * 2)))
    bsig = _block_sigmas(dblock_blocks)
    groups = _dblock_block_layers(core, dblock_blocks)
    sigma_min = float(bsig[0])
    start = float(getattr(args, "euler_start_sigma", 0.0) or 0.0)
    if start <= 0.0:
        start = float(bsig[-1])
    # Keep the start strictly above sigma_min so the log-spaced schedule is not degenerate.
    start = max(start, sigma_min * 2)
    mask = causal_mask(ids.size(1), structured=use_structured_masks(args))
    e = core.emb(ids)
    lo, hi = math.log(sigma_min), math.log(start)
    # steps + 1 sigmas, geometrically spaced from `start` (i = 0) to `sigma_min` (i = steps).
    sched = [float(_np.exp(hi + (lo - hi) * (i / steps))) for i in range(steps + 1)]
    # Initial state: token embeddings plus Gaussian noise at the starting sigma.
    z = e + sched[0] * torch.randn_like(e)
    # NOTE(review): the source indentation was lost; the final denoise below is
    # assumed to run inside the same amp context as the loop — confirm.
    with amp(getattr(args, "amp", False)):
        for i in range(steps):
            s_cur, s_next = sched[i], sched[i + 1]
            b = _dblock_select_block(s_cur, bsig)
            # One sigma value per batch row.
            sig_t = torch.full((ids.size(0),), s_cur, device=ids.device, dtype=z.dtype)
            D = _edm_denoise_block(core, groups[b], z, sig_t, mask, args)
            # Euler step along (z - D) / sigma with step size (s_next - s_cur).
            z = z + ((s_next - s_cur) / s_cur) * (z - D)
        sig0 = torch.full((ids.size(0),), sigma_min, device=ids.device, dtype=z.dtype)
        D0 = _edm_denoise_block(core, groups[0], z, sig0, mask, args)
    return core.ln(D0)
@torch.no_grad()
def infer(args):
    """Load a checkpoint and generate a completion for ``args.prompt``.

    Steps, in order:

    1. Fill sampling options left at ``None`` with per-mode defaults.
    2. Load the checkpoint, optionally restoring the tokenizer stored in it,
       and rebuild the encoder plus AR/SAT (and, when present or requested,
       NAT) heads from the stored config. Delta checkpoints carry no config,
       so the "large" preset is used for them.
    3. Generate according to ``args.mode``:
       ``"ar"``  - token-by-token decode, either with a KV cache or, when
       ``args.sampler == "euler"``, by recomputing the hidden state with
       ``_dblock_euler_hidden`` each step;
       ``"nat"`` - iterative mask-predict over ``args.max_new`` BLANK slots;
       anything else - SAT block decode with a fixed or gate-sampled stride.
    4. Print prompt, completion and throughput, plus a delimited
       machine-friendly block when ``args.claude_friendly`` is set.

    Args:
        args: parsed CLI namespace of the ``infer`` sub-command.

    Returns:
        None. All results are printed.

    Raises:
        ValueError: NAT mode was requested but the checkpoint has no NAT head.
    """
    # Per-mode sampling defaults; only values the user left unset are touched.
    if args.mode == "ar":
        if args.temperature is None: args.temperature = 0.7
        if args.top_k is None: args.top_k = 0
        if args.repetition_penalty is None: args.repetition_penalty = 1.3
        if args.presence_penalty is None: args.presence_penalty = 0.0
        if args.frequency_penalty is None: args.frequency_penalty = 0.3
        if args.penalty_last_n is None: args.penalty_last_n = 128
        if args.var is None: args.var = False
    elif args.mode == "sat":
        if args.temperature is None: args.temperature = 0.5
        if args.top_k is None: args.top_k = 30
        if args.repetition_penalty is None: args.repetition_penalty = 2.0
        if args.presence_penalty is None: args.presence_penalty = 0.6
        if args.frequency_penalty is None: args.frequency_penalty = 1.0
        if args.penalty_last_n is None: args.penalty_last_n = 200
        if args.var is None: args.var = True
    else:
        if args.temperature is None: args.temperature = 0.8
        if args.top_k is None: args.top_k = 50
        if args.repetition_penalty is None: args.repetition_penalty = 1.6
        if args.presence_penalty is None: args.presence_penalty = 0.6
        if args.frequency_penalty is None: args.frequency_penalty = 1.0
        if args.penalty_last_n is None: args.penalty_last_n = 512
        if args.var is None: args.var = False

    path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only point --ckpt at checkpoints from a trusted source.
    sd = torch.load(path, map_location="cpu")

    # Restore tokenizer from checkpoint if available
    if "tokenizer_json" in sd:
        try:
            from tokenizers import Tokenizer as _Tokenizer
            tok.backend_tokenizer = _Tokenizer.from_str(sd["tokenizer_json"])
            print("[tokenizer] Restored from checkpoint")
        except Exception as e:
            # Best effort: keep the already-loaded tokenizer and report why.
            print(f"[tokenizer] WARNING: could not restore from checkpoint: {e}")

    # Warn if transformers version changed since checkpoint was saved
    if "transformers_version" in sd:
        import transformers as _tf
        if sd["transformers_version"] != _tf.__version__:
            print(f"[tokenizer] WARNING: checkpoint saved with transformers={sd['transformers_version']}, now running {_tf.__version__}")

    # Handle delta checkpoints (weight-only, no cfg)
    if sd.get("delta"):
        print("[infer] Delta checkpoint detected, using large preset cfg")
        cfg = PRESETS["large"].copy()
        tie_weights = False
        # Remap: delta stores under sd["weights"]["core"/"ar"/"sat"/"nat"]
        sd["core"] = sd["weights"]["core"]
        sd["ar"] = sd["weights"]["ar"]
        sd["sat"] = sd["weights"]["sat"]
        if "nat" in sd["weights"]:
            sd["nat"] = sd["weights"]["nat"]
    else:
        cfg = sd["cfg"]
        tie_weights = sd.get("tie_weights", False)

    # Plain output is forced whenever stdout is not a terminal.
    plain_output = (
        bool(getattr(args, "plain_output", False))
        or bool(getattr(args, "claude_friendly", False))
        or not sys.stdout.isatty()
    )
    uk_time = get_uk_time()
    ckpt_name = path.name
    if plain_output:
        print(f"[infer] inference_time={uk_time}")
        print(f"[infer] checkpoint={ckpt_name}")
    else:
        print(f"┌─────────────────────────────────────────────────┐")
        print(f"│ INFERENCE @ {uk_time:<35s} │")
        print(f"├─────────────────────────────────────────────────┤")
        print(f"│ Checkpoint: {ckpt_name:<35s} │")
        print(f"└─────────────────────────────────────────────────┘")
    print_expansion_info(cfg, tie_weights, plain=plain_output)

    # Rebuild the model from the stored config and the CLI attention options.
    core = Encoder(
        cfg,
        tie_weights=tie_weights,
        attn_backend=args.attn_backend,
        sublinear_window=args.sublinear_window,
        sublinear_stride=args.sublinear_stride,
        sublinear_max_anchors=args.sublinear_max_anchors,
        sublinear_chunk=args.sublinear_chunk,
        sublinear_sinks=args.sublinear_sinks,
        sublinear_recent_anchors=args.sublinear_recent_anchors,
        sublinear_pooled_landmarks=args.sublinear_pooled_landmarks,
        anchor_memory=getattr(args, "anchor_memory", DEFAULT_ANCHOR_MEMORY),
        anchor_stride=getattr(args, "anchor_stride", DEFAULT_ANCHOR_STRIDE),
        anchor_max=getattr(args, "anchor_max", DEFAULT_ANCHOR_MAX),
        anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION),
    ).to(DEV)
    ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV)
    sat_head_mlp = bool(sd.get("sat_head_mlp", False) or _sat_head_mlp_from_state(sd))
    sat_h = SATHead(cfg["d"], mlp=sat_head_mlp).to(DEV)
    # The NAT head is only built when the checkpoint has one or NAT mode was asked for.
    nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if ("nat" in sd or args.mode == "nat") else None
    core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"]))
    ar_h.load_state_dict(sd["ar"])
    _load_infer_head_state(sat_h, sd["sat"], "SATHead")
    if nat_h is not None:
        if "nat" not in sd:
            raise ValueError("NAT inference requested, but this checkpoint has no NAT head")
        _load_infer_head_state(nat_h, sd["nat"], "NATHead")
    core.eval()
    ar_h.eval()
    sat_h.eval()
    if nat_h is not None:
        nat_h.eval()

    total_params = _count_enabled_params(core, ar_h, sat_h, nat_h)
    if total_params >= 1_000_000_000:
        param_str = f"{total_params / 1_000_000_000:.2f}B"
    elif total_params >= 1_000_000:
        param_str = f"{total_params / 1_000_000:.2f}M"
    elif total_params >= 1_000:
        param_str = f"{total_params / 1_000:.2f}K"
    else:
        param_str = f"{total_params}"
    print(f"Model size: {param_str} parameters ({total_params:,})")

    prompt_tokens = tok.encode(args.prompt)
    prompt_len = len(prompt_tokens)
    ids = torch.tensor([prompt_tokens], device=DEV)
    if ids.size(1) == 0:
        # Empty prompt: seed generation with EOS. prompt_tokens is updated as
        # well because the NAT branch rebuilds `ids` from it and addresses its
        # slots starting at prompt_len; leaving it empty made those positions
        # run one past the end of the sequence.
        prompt_tokens = [EOS]
        ids = torch.tensor([[EOS]], device=DEV)
        prompt_len = 1

    mode_str = args.mode
    if args.mode == "sat":
        mode_str = f"sat-{'var' if args.var else 'fixed'}"
    if plain_output:
        print(f"Generating ({mode_str})...")
    else:
        print(f"{Colors.INFO}Generating ({mode_str})...{Colors.RESET}")
    start = time.time()

    if args.mode == "ar":
        _euler = getattr(args, "sampler", "ar") == "euler"
        if not _euler:
            # Prefill the KV cache with the whole prompt.
            h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1))
        for _ in range(args.max_new):
            if _euler:
                # The Euler sampler has no cache: recompute from the full sequence.
                h = _dblock_euler_hidden(core, ids, args)
            logits = ar_h(h)[:, -1]
            logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
            nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
            ids = torch.cat([ids, nxt], 1)
            if EOS is not None and int(nxt.item()) == int(EOS):
                break
            if not _euler:
                # Feed only the new token; earlier positions come from the cache.
                h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
    elif args.mode == "nat":
        # Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the
        # BLANK slots, committing confident predictions each pass. Unlike the
        # original straight argmax path, this applies the same anti-repetition
        # penalties and sampler used by AR/SAT at each committed position.
        n_fill = max(1, int(args.max_new))
        ids = torch.tensor([prompt_tokens + [BLANK] * n_fill], device=DEV)
        remaining = set(range(prompt_len, prompt_len + n_fill))
        passes = max(1, int(args.nat_passes))

        def _nat_history(current_ids: torch.Tensor):
            # Penalty history = every position that is not a BLANK placeholder.
            keep = current_ids[0] != BLANK
            if bool(keep.any()):
                return current_ids[:, keep]
            return current_ids[:, :max(1, prompt_len)]

        def _nat_pick(logits_pos: torch.Tensor, current_ids: torch.Tensor):
            # Clone so penalties for one slot do not leak into the shared logits.
            logits_pos = logits_pos.clone()
            logits_pos[..., BLANK] = -1e9
            logits_pos = _apply_penalties(
                logits_pos,
                _nat_history(current_ids),
                args.penalty_last_n,
                args.repetition_penalty,
                args.presence_penalty,
                args.frequency_penalty,
            )
            return _sample(logits_pos, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)

        for p in range(passes):
            if not remaining:
                break
            h = core(ids, None)
            logits = nat_h(h)
            logits[..., BLANK] = -1e9
            conf = logits.softmax(-1).amax(-1)
            # Ceil-divide the open slots over the passes that are left.
            k = max(1, -(-len(remaining) // (passes - p)))
            ordered = sorted(remaining, key=lambda q: float(conf[0, q]), reverse=True)[:k]
            for pos in ordered:
                nxt = _nat_pick(logits[:, pos, :], ids)
                ids[0, pos] = int(nxt.reshape(-1)[0])
                remaining.discard(pos)
        if remaining:
            # Safety net: fill anything still open with one more forward pass.
            h = core(ids, None)
            logits = nat_h(h)
            logits[..., BLANK] = -1e9
            for pos in sorted(remaining):
                nxt = _nat_pick(logits[:, pos, :], ids)
                ids[0, pos] = int(nxt.reshape(-1)[0])
    else:
        cached_len = ids.size(1)
        h, kvs = core(ids, sat_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=cached_len)
        # Rolling window of the most recent SAT_BLOCK hidden states.
        h_buffer = h[:, -SAT_BLOCK:]
        added = 0
        stop = False
        # Align to block boundary if prompt is off-boundary
        if ids.size(1) % SAT_BLOCK != 0:
            logits = ar_h(h)[:, -1]
            logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
            nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
            ids = torch.cat([ids, nxt], 1)
            added += 1
            if EOS is not None and int(nxt.item()) == int(EOS):
                stop = True
            if not stop:
                h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
                cached_len = ids.size(1)
                h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]
        while added < args.max_new and not stop:
            logits_all, gate = sat_h(h_buffer)
            # Fixed stride uses the full block; variable stride samples it from the gate.
            stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1)
            stride = min(int(stride), logits_all.size(1))
            new_tokens = []
            for i in range(int(stride)):
                logits = logits_all[:, i]
                logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty)
                nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy)
                new_tokens.append(nxt)
                ids = torch.cat([ids, nxt], 1)
                added += 1
                if EOS is not None and int(nxt.item()) == int(EOS):
                    stop = True
                    break
                if added >= args.max_new: break
            if stop or added >= args.max_new: break
            # Push the freshly sampled tokens through the encoder to extend the cache.
            new_ids = torch.cat(new_tokens, dim=1)
            mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args))
            h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1))
            cached_len = ids.size(1)
            h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:]

    elapsed = time.time() - start
    gen_tokens = len(ids[0]) - prompt_len
    tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0
    all_tokens = ids[0].tolist()
    prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True)
    gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True)
    safe_prompt = _ascii_safe(prompt_text) if plain_output else prompt_text
    safe_gen = _ascii_safe(gen_text) if plain_output else gen_text
    if plain_output:
        print(f"{safe_prompt}{safe_gen}")
        print(f"[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]")
    else:
        print(f"{Colors.PROMPT}{safe_prompt}{Colors.RESET}{safe_gen}")
        print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}")
    if getattr(args, "claude_friendly", False):
        # Delimited, ASCII-only block intended for downstream parsers.
        claude_prompt = _ascii_safe(prompt_text)
        claude_gen = _ascii_safe(gen_text)
        print("[CLAUDE_FRIENDLY_START]")
        print(f"[mode={mode_str}]")
        print("[prompt_input]")
        print(claude_prompt)
        print("[completion]")
        print(claude_gen)
        print("[prompt_plus_completion]")
        print(f"{claude_prompt}{claude_gen}")
        print(f"[stats] {elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s")
        print("[CLAUDE_FRIENDLY_END]")
# ───────────────────────── CLI ─────────────────────────
def main():
    """Command-line entry point: parse arguments and dispatch the sub-command.

    Sub-commands:
        train  - calls ``train(args)``
        infer  - calls ``infer(args)``
        status - exits the process with the return value of ``_emit_status``

    Raises:
        SystemExit: always for ``status``; also raised by argparse on invalid
            arguments.
    """

    def _bool_arg(value):
        """Parse an explicit boolean CLI value.

        ``type=bool`` would treat every non-empty string - including
        "False" and "0" - as True, so the value is matched against explicit
        spellings instead.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if text in {"0", "false", "f", "no", "n", "off", ""}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # ---- train ----
    tr = sub.add_parser("train")
    tr.add_argument("--preset", choices=PRESETS.keys(), default="large")
    tr.add_argument("--rank", type=int)
    tr.add_argument("--block", type=int, default=DEFAULT_BLOCK)
    tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH)
    tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES)
    tr.add_argument("--target_tokens", type=int)
    tr.add_argument("--token_param_ratio", type=float, default=0.0,
                    help="If --target_tokens is omitted, train to this tokens:param ratio. AGILLM-4 presets default to 100.")
    tr.add_argument("--steps", type=int)
    tr.add_argument("--amp", action="store_true")
    tr.add_argument("--compile", action="store_true", help="Use torch.compile for speedup")
    tr.add_argument("--attn_backend", choices=["manual", "sdpa", "sublinear"], default=DEFAULT_ATTN_BACKEND,
                    help="AGILLM-4 attention backend. sublinear uses local-window plus landmark candidates.")
    tr.add_argument("--grad_checkpoint", action="store_true",
                    help="Recompute transformer blocks during backward to trade speed for longer context.")
    tr.add_argument("--sublinear_window", type=int, default=DEFAULT_SUBLINEAR_WINDOW,
                    help="For --attn_backend sublinear, attend to this many local tokens on each side.")
    tr.add_argument("--sublinear_stride", type=int, default=DEFAULT_SUBLINEAR_STRIDE,
                    help="For --attn_backend sublinear, use every Nth token as a landmark candidate.")
    tr.add_argument("--sublinear_max_anchors", type=int, default=DEFAULT_SUBLINEAR_MAX_ANCHORS,
                    help="For --attn_backend sublinear, cap landmark candidates per query chunk.")
    tr.add_argument("--sublinear_chunk", type=int, default=DEFAULT_SUBLINEAR_CHUNK,
                    help="For --attn_backend sublinear, query chunk size controlling peak gather memory.")
    tr.add_argument("--sublinear_sinks", type=int, default=DEFAULT_SUBLINEAR_SINKS,
                    help="For sublinear attention, always include this many first-token attention sinks.")
    tr.add_argument("--sublinear_recent_anchors", type=int, default=DEFAULT_SUBLINEAR_RECENT_ANCHORS,
                    help="For capped sublinear anchors, reserve this many anchors for the recent tail; -1 uses half.")
    tr.add_argument("--sublinear_pooled_landmarks", action=argparse.BooleanOptionalAction,
                    default=DEFAULT_SUBLINEAR_POOLED_LANDMARKS,
                    help="Use stride-segment pooled K/V summaries for sublinear landmark anchors.")
    tr.add_argument("--no_structured_masks", action="store_true",
                    help="Disable structured causal/SAT masks for sublinear attention and fall back to dense masks.")
    tr.add_argument("--anchor_memory", action="store_true",
                    help="Enable anchor-memory long-context augmentation (one AnchorMemoryLayer at mid-stack).")
    tr.add_argument("--anchor_stride", type=int, default=DEFAULT_ANCHOR_STRIDE,
                    help="Token span compressed into one anchor (default 256).")
    tr.add_argument("--anchor_max", type=int, default=DEFAULT_ANCHOR_MAX,
                    help="Max anchors retained in the rolling memory bank.")
    tr.add_argument("--anchor_position", type=int, default=DEFAULT_ANCHOR_POSITION,
                    help="Block index after which to insert anchor memory (-1 = stack middle).")
    tr.add_argument("--kv_buffer", action="store_true",
                    help="Use preallocated KV buffer instead of torch.cat-based cache growth.")
    tr.add_argument("--optimizer", choices=["adamw", "adamw8bit", "paged_adamw8bit"], default="adamw",
                    help="Optimizer backend. 8-bit options reduce VRAM on 24GB production runs.")
    tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC)
    tr.add_argument("--heartbeat_every_sec", type=int, default=300,
                    help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.")
    tr.add_argument("--empty_cache_every_steps", type=int, default=0,
                    help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.")
    tr.add_argument("--profile_steps", type=int, default=0,
                    help="Profile the first N DBlock training steps with in-process CUDA timers; 0 disables.")
    tr.add_argument("--profile_log_every", type=int, default=25,
                    help="Print averaged profiler timings every N profiled steps.")
    tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)")
    tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep")
    tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)")
    tr.add_argument("--async_update_dir", default="",
                    help="Optional incoming directory for verified DBlock side updates. Empty disables async side updates.")
    tr.add_argument("--async_update_every_steps", type=int, default=0,
                    help="Poll --async_update_dir every N master steps. Side workers never block master progress.")
    tr.add_argument("--async_update_alpha", type=float, default=1.0,
                    help="Blend factor for accepted side updates: 1.0 copies side block weights; lower values lerp into live weights.")
    tr.add_argument("--async_update_max_per_check", type=int, default=1,
                    help="Maximum side-update files to apply per poll.")
    tr.add_argument("--async_update_max_age_sec", type=float, default=0.0,
                    help="Reject incoming side updates older than this many seconds. 0 disables age rejection.")
    tr.add_argument("--async_update_accepted_dir", default="",
                    help="Directory for applied side-update files. Defaults to a sibling accepted/ directory.")
    tr.add_argument("--async_update_rejected_dir", default="",
                    help="Directory for rejected side-update files. Defaults to a sibling rejected/ directory.")
    tr.add_argument("--save_dir", default=str(CKDIR))
    tr.add_argument("--resume", type=str)
    tr.add_argument("--x2", action="store_true")
    tr.add_argument("--warmstart_from", type=str)
    tr.add_argument("--fresh", action="store_true")
    tr.add_argument("--max_ckpts", type=int, default=None)
    tr.add_argument("--chilla_max_double", action="store_true")
    tr.add_argument("--tie_weights", action="store_true")
    tr.add_argument("--ar_only", action="store_true")
    tr.add_argument("--agillm3_compat", action="store_true",
                    help="Legacy AGILLM3/3.5 checkpoint mode. Use TOKENIZER_ID=deepseek-ai/DeepSeek-V3.2 or the agillm35.py shim for the old tokenizer contract.")
    tr.add_argument("--no_nat_head", action="store_true",
                    help="Do not instantiate/save a NAT head. Keeps AGILLM3 AR+SAT checkpoint schema and reduces params/RAM.")
    tr.add_argument("--sat_every", type=int, default=1,
                    help="Train SAT every N steps. Default 1 keeps AR+SAT every step.")
    tr.add_argument("--nat_every", type=int, default=1,
                    help="Train NAT every N steps with a CTC objective. Default 1 keeps AR+SAT+NAT every step.")
    tr.add_argument("--nat_loss_weight", type=float, default=1.0)
    tr.add_argument("--nat_expand", type=int, default=2,
                    help="Repeat tokens this many times for the NAT CTC input length.")
    tr.add_argument("--nat_max_tokens", type=int, default=0,
                    help="Optional cap for NAT target tokens per batch; 0 uses the whole block.")
    tr.add_argument("--nat_mask_ratio", type=float, default=0.5,
                    help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.")
    tr.add_argument("--moe_ffn", action=argparse.BooleanOptionalAction, default=DEFAULT_MOE_FFN,
                    help="Use Mixture-of-Experts feed-forward layers inside the transformer blocks.")
    tr.add_argument("--moe_experts", type=int, default=DEFAULT_MOE_EXPERTS,
                    help="Number of FFN experts per transformer block when --moe_ffn is enabled.")
    tr.add_argument("--moe_top_k", type=int, default=DEFAULT_MOE_TOP_K,
                    help="Router top-k experts per token when --moe_ffn is enabled.")
    tr.add_argument("--moe_mlp_mult", type=int, default=DEFAULT_MOE_MLP_MULT,
                    help="Expert hidden-size multiplier; 4 preserves dense FFN checkpoint shape for seeding.")
    tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).")
    tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.")
    tr.add_argument("--dblock_schedule", choices=["random", "roundrobin", "loss_balanced"], default="loss_balanced",
                    help="How --dblock chooses the next layer block. loss_balanced focuses blocks whose EMA loss is highest after warmup.")
    tr.add_argument("--dblock_warmup_steps", type=int, default=16,
                    help="Initial DBlock steps spent covering every block before loss-balanced scheduling.")
    tr.add_argument("--dblock_explore", type=float, default=0.05,
                    help="Exploration rate for loss-balanced DBlock scheduling.")
    tr.add_argument("--dblock_log_every", type=int, default=25,
                    help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.")
    tr.add_argument("--dblock_checkpoint_stride", type=int, default=1,
                    help="With --grad_checkpoint in --dblock mode, checkpoint one layer every N selected block layers; 1=all layers, 2=alternate, 0=off.")
    tr.add_argument("--dblock_checkpoint_skip_tail", type=int, default=0,
                    help="Experimental DBlock speed knob: do not checkpoint this many final layers in the selected block, reducing backward recompute at higher VRAM cost.")
    tr.add_argument("--dblock_activation_offload", action="store_true",
                    help="Experimental DBlock speed knob: for non-checkpointed block layers, offload saved backward tensors to CPU RAM instead of recomputing.")
    tr.add_argument("--dblock_activation_offload_min_mb", type=float, default=1.0,
                    help="Minimum CUDA tensor size in MB to offload under --dblock_activation_offload.")
    tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000,
                    help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.")
    tr.add_argument("--dblock_edm_wmax", type=float, default=5.0,
                    help="Cap for EDM loss weighting in DBlock mode.")
    tr.add_argument("--dblock_ar_weight", type=float, default=1.0)
    tr.add_argument("--dblock_sat_weight", type=float, default=1.0)
    tr.add_argument("--dblock_nat_weight", type=float, default=1.0)
    tr.add_argument("--dblock_objective_mode", choices=["periodic", "stochastic"], default="periodic",
                    help="DBlock objective scheduler. stochastic samples one objective per step to reduce redundant AR/SAT/NAT forwards.")
    tr.add_argument("--dblock_ar_prob", type=float, default=0.80, help="Stochastic DBlock probability for AR objective.")
    tr.add_argument("--dblock_sat_prob", type=float, default=0.10, help="Stochastic DBlock probability for SAT objective.")
    tr.add_argument("--dblock_nat_prob", type=float, default=0.10, help="Stochastic DBlock probability for NAT objective.")
    tr.add_argument("--dblock_ar_loss_tokens", type=int, default=0,
                    help="If >0, uniformly sample this many AR target positions per DBlock step for stochastic token-level CE.")
    tr.add_argument("--dblock_sat_loss_tokens", type=int, default=0,
                    help="If >0, uniformly sample this many SAT target positions per DBlock step.")
    tr.add_argument("--dblock_nat_loss_tokens", type=int, default=0,
                    help="If >0, uniformly sample this many NAT target positions per DBlock step.")
    tr.add_argument("--reinit_nat", action="store_true",
                    help="Reinitialize NAT head weights after load (use once when switching to mask-predict).")
    tr.add_argument("--seed_nat_from_ar", action="store_true",
                    help="Seed the NAT head from the trained AR head ('father') after load instead of random init.")
    tr.add_argument("--freeze_core", action="store_true")
    tr.add_argument("--unfreeze_ln", action="store_true")
    tr.add_argument("--train_emb", action="store_true")
    tr.add_argument("--lr_core", type=float, default=LR_CORE)
    tr.add_argument("--lr_head", type=float, default=LR_HEAD)
    tr.add_argument("--chat", action="store_true")
    tr.add_argument("--chat_messages_key", default="messages")
    tr.add_argument("--dataset_field_text", default="text")
    tr.add_argument("--sft_add_generation_prompt", action="store_true")
    tr.add_argument("--auto_grow", action="store_true")
    tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122")
    tr.add_argument("--grow_every_steps", type=int, default=50000)
    tr.add_argument("--after_sft_source", default="")
    tr.add_argument("--after_sft_steps", type=int, default=0)
    tr.add_argument("--after_sft_chat", action="store_true")
    tr.add_argument("--after_sft_chat_messages_key", default="messages")
    tr.add_argument("--after_sft_dataset_field_text", default="text")
    # Tri-state option: None (unset) lets the trainer pick its own default, so it
    # still takes an explicit value, now parsed properly instead of via bool().
    tr.add_argument("--after_sft_add_generation_prompt", type=_bool_arg, default=None)
    tr.add_argument("--after_sft_block", type=int, default=0)
    tr.add_argument("--after_sft_freeze_core", action="store_true")
    tr.add_argument("--after_sft_unfreeze_ln", action="store_true")
    tr.add_argument("--after_sft_train_emb", action="store_true")
    tr.add_argument("--after_sft_lr_core", type=float, default=0.0)
    tr.add_argument("--after_sft_lr_head", type=float, default=0.0)

    # ---- infer ----
    inf = sub.add_parser("infer")
    inf.add_argument("--mode", choices=["ar", "sat", "nat"], required=True)
    inf.add_argument("--sampler", choices=["ar", "euler"], default="ar", help="ar=KV decode; euler=DiffusionBlocks EDM Euler sampler.")
    inf.add_argument("--euler_steps", type=int, default=0, help="Euler ODE steps (0=2x dblock_blocks).")
    inf.add_argument("--euler_start_sigma", type=float, default=0.0, help="Euler start noise (0=sigma_max; lower=stronger context conditioning).")
    inf.add_argument("--dblock_blocks", type=int, default=4, help="Number of DiffusionBlocks for the Euler sampler.")
    inf.add_argument("--ckpt", required=True)
    inf.add_argument("--prompt", required=True)
    inf.add_argument("--max_new", type=int, default=120)
    # Sampling options defaulting to None are filled per mode inside infer().
    inf.add_argument("--temperature", type=float, default=None)
    inf.add_argument("--greedy", action="store_true")
    inf.add_argument("--top_k", type=int, default=None)
    inf.add_argument("--top_p", type=float, default=0.9)
    inf.add_argument("--min_p", type=float, default=0.0)
    inf.add_argument("--repetition_penalty", type=float, default=None)
    inf.add_argument("--presence_penalty", type=float, default=None)
    inf.add_argument("--frequency_penalty", type=float, default=None)
    inf.add_argument("--penalty_last_n", type=int, default=None)
    inf.add_argument("--var", action="store_true", default=None)
    inf.add_argument("--no-var", dest="var", action="store_false")
    inf.add_argument("--claude-friendly", action="store_true", help="Also print an artifact-free prompt/completion block for downstream JSON consumers")
    inf.add_argument("--plain-output", "--no-color", dest="plain_output", action="store_true", help="Use plain ASCII/no ANSI output for redirected inference logs")
    inf.add_argument("--attn_backend", choices=["manual", "sdpa", "sublinear"], default=DEFAULT_ATTN_BACKEND)
    inf.add_argument("--sublinear_window", type=int, default=DEFAULT_SUBLINEAR_WINDOW)
    inf.add_argument("--sublinear_stride", type=int, default=DEFAULT_SUBLINEAR_STRIDE)
    inf.add_argument("--sublinear_max_anchors", type=int, default=DEFAULT_SUBLINEAR_MAX_ANCHORS)
    inf.add_argument("--sublinear_chunk", type=int, default=DEFAULT_SUBLINEAR_CHUNK)
    inf.add_argument("--sublinear_sinks", type=int, default=DEFAULT_SUBLINEAR_SINKS)
    inf.add_argument("--sublinear_recent_anchors", type=int, default=DEFAULT_SUBLINEAR_RECENT_ANCHORS)
    inf.add_argument("--sublinear_pooled_landmarks", action=argparse.BooleanOptionalAction,
                     default=DEFAULT_SUBLINEAR_POOLED_LANDMARKS)
    inf.add_argument("--no_structured_masks", action="store_true")
    inf.add_argument("--nat_expand", type=int, default=2)
    inf.add_argument("--nat_passes", type=int, default=1)

    # ---- status ----
    st = sub.add_parser("status", help="Read-only training status")
    st.add_argument("--json", dest="json_output", action="store_true")
    st.add_argument("--log", type=str, default=str(STATUS_DEFAULT_LOG))
    st.add_argument("--save_dir", type=str, default=str(STATUS_DEFAULT_SAVE_DIR))

    args = ap.parse_args()
    if args.cmd == "train":
        train(args)
    elif args.cmd == "infer":
        infer(args)
    else:
        # The value returned by _emit_status becomes the process exit status.
        raise SystemExit(_emit_status(Path(args.log), Path(args.save_dir), args.json_output))
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
# ===== END nB300_agillm4.py =====

# Xet Storage Details
#
# Size: 195 kB · Xet hash: 855e8a1252186aadaa8ca9eb783a7ec308c0ddcd62937505c3dc9f91e0510276
#
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.