Buckets:
| #!/usr/bin/env python3 | |
| """AGILLM4.1 mainline single-file trainer/inference runtime. | |
| AGILLM4.1 is the promoted AGILLM4 mainline evolved from the AGILLM3.5 | |
| prototype, and it is larger than AGILLM3/AGILLM3.5. Resumed checkpoints are | |
| the source of truth for the exact architecture, with AGILLM4 presets available | |
| for fresh starts. This file is mechanically folded from AGILLM4 plus | |
| compatibility patches: | |
| - DeepSeek-V4-Pro tokenizer/checkpoint support by default | |
| - DeepSeek-V3.2 legacy compatibility support through the agillm35 shim | |
| - AR + SAT checkpoint schema compatibility; NAT can be disabled with --agillm3_compat | |
| - DiffusionBlock training support and optional async side-update ingestion | |
| """ | |
| from __future__ import annotations | |
| # Single-file module alias: helper code still imports the historical module names. | |
| import sys as _agillm41_sys | |
| _agillm41_sys.modules.setdefault("nB300_agillm4", _agillm41_sys.modules[__name__]) | |
| _agillm41_sys.modules.setdefault("agillm35", _agillm41_sys.modules[__name__]) | |
| _agillm41_sys.modules.setdefault("agillm41", _agillm41_sys.modules[__name__]) | |
| # ===== BEGIN anchor_memory.py ===== | |
| #!/usr/bin/env python3 | |
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class AnchorMemoryConfig: | |
| d_model: int | |
| heads: int | |
| anchor_stride: int = 256 | |
| max_anchors: int = 2048 | |
| dropout: float = 0.0 | |
| class AnchorCompressor(nn.Module): | |
| """Compress local token spans into trainable anchor vectors.""" | |
| def __init__(self, d_model: int, anchor_stride: int): | |
| super().__init__() | |
| self.anchor_stride = anchor_stride | |
| self.score = nn.Linear(d_model, 1) | |
| self.mix = nn.Sequential( | |
| nn.LayerNorm(d_model), | |
| nn.Linear(d_model, 4 * d_model), | |
| nn.GELU(), | |
| nn.Linear(4 * d_model, d_model), | |
| ) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| bsz, seq, dim = x.shape | |
| pad = (-seq) % self.anchor_stride | |
| if pad: | |
| x = F.pad(x, (0, 0, 0, pad)) | |
| chunks = x.view(bsz, -1, self.anchor_stride, dim) | |
| weights = self.score(chunks).softmax(dim=2) | |
| pooled = (chunks * weights).sum(dim=2) | |
| return pooled + self.mix(pooled) | |
| class AnchorMemoryLayer(nn.Module): | |
| """Local-token stream reads from a bounded bank of learned anchors.""" | |
| def __init__(self, cfg: AnchorMemoryConfig): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.compress = AnchorCompressor(cfg.d_model, cfg.anchor_stride) | |
| self.q_ln = nn.LayerNorm(cfg.d_model) | |
| self.mem_ln = nn.LayerNorm(cfg.d_model) | |
| self.read = nn.MultiheadAttention( | |
| cfg.d_model, | |
| cfg.heads, | |
| dropout=cfg.dropout, | |
| batch_first=True, | |
| ) | |
| self.gate = nn.Sequential(nn.Linear(2 * cfg.d_model, cfg.d_model), nn.Sigmoid()) | |
| self.out_ln = nn.LayerNorm(cfg.d_model) | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| memory: torch.Tensor | None = None, | |
| *, | |
| detach_memory: bool = False, | |
| ) -> tuple[torch.Tensor, torch.Tensor]: | |
| new_anchors = self.compress(x) | |
| if detach_memory: | |
| new_anchors = new_anchors.detach() | |
| if memory is None: | |
| bank = new_anchors | |
| else: | |
| bank = torch.cat([memory, new_anchors], dim=1) | |
| if bank.size(1) > self.cfg.max_anchors: | |
| bank = bank[:, -self.cfg.max_anchors :] | |
| recalled, _ = self.read(self.q_ln(x), self.mem_ln(bank), self.mem_ln(bank), need_weights=False) | |
| gate = self.gate(torch.cat([x, recalled], dim=-1)) | |
| mixed = x + gate * recalled | |
| return self.out_ln(mixed), bank | |
| def smoke_test() -> None: | |
| cfg = AnchorMemoryConfig(d_model=128, heads=8, anchor_stride=32, max_anchors=64) | |
| layer = AnchorMemoryLayer(cfg) | |
| x = torch.randn(2, 256, 128) | |
| y, memory = layer(x) | |
| assert y.shape == x.shape | |
| assert memory.shape == (2, 8, 128) | |
| y2, memory2 = layer(x, memory) | |
| assert y2.shape == x.shape | |
| assert memory2.shape == (2, 16, 128) | |
| print("anchor_memory smoke OK", y.shape, memory2.shape) | |
| # ===== END anchor_memory.py ===== | |
| # ===== BEGIN fused_ce.py ===== | |
| """Fused cross-entropy: streams over the VOCAB dimension (online-softmax) so the | |
| [N x V] logit matrix is NEVER materialized -- only [N x vchunk]. Custom backward | |
| recomputes softmax per vocab-chunk (grad = softmax - onehot). This is the | |
| DiffusionBlocks 'process in chunks, don't hold the whole thing' idea applied to | |
| the output head instead of network depth.""" | |
| import torch | |
| class FusedCE(torch.autograd.Function): | |
| def forward(ctx, h, W, tgt, vchunk=16384): | |
| with torch.cuda.amp.autocast(enabled=False): | |
| hf = h.float() | |
| Wf = W.float() | |
| N, d = h.shape | |
| V = W.shape[0] | |
| m = torch.full((N,), -1e30, device=h.device, dtype=torch.float32) | |
| s = torch.zeros(N, device=h.device, dtype=torch.float32) | |
| zt = torch.zeros(N, device=h.device, dtype=torch.float32) | |
| for c in range(0, V, vchunk): | |
| lg = hf @ Wf[c:c+vchunk].T # [N,vchunk] transient only | |
| cm = lg.max(1).values | |
| nm = torch.maximum(m, cm) | |
| s = s * torch.exp(m - nm) + torch.exp(lg - nm[:, None]).sum(1) | |
| m = nm | |
| ic = (tgt >= c) & (tgt < c+vchunk) | |
| if ic.any(): | |
| zt[ic] = lg[ic, tgt[ic] - c].float() | |
| lse = m + torch.log(s) | |
| ctx.save_for_backward(h, W, tgt, lse) | |
| ctx.vchunk = vchunk | |
| return (lse - zt).mean() | |
| def backward(ctx, go): | |
| h, W, tgt, lse = ctx.saved_tensors | |
| vc = ctx.vchunk | |
| N, d = h.shape | |
| V = W.shape[0] | |
| with torch.cuda.amp.autocast(enabled=False): | |
| hf = h.float() | |
| Wc_all = W.float() | |
| gh = torch.zeros_like(hf) | |
| gW = torch.zeros(W.shape, device=W.device, dtype=torch.float32) | |
| sc = float(go) / N | |
| for c in range(0, V, vc): | |
| Wc = Wc_all[c:c+vc] | |
| p = torch.exp(hf @ Wc.T - lse[:, None]) # softmax chunk [N,vchunk] | |
| ic = (tgt >= c) & (tgt < c+vc) | |
| if ic.any(): | |
| p[ic, tgt[ic] - c] -= 1.0 | |
| p *= sc | |
| gh += p @ Wc | |
| gW[c:c+vc] += p.T @ hf | |
| return gh.to(h.dtype), gW.to(W.dtype), None, None | |
| def fused_ce(h, W, tgt, vchunk=16384): | |
| return FusedCE.apply(h.reshape(-1, h.size(-1)), W, tgt.reshape(-1), vchunk) | |
| # ===== END fused_ce.py ===== | |
| # ===== BEGIN dblocks_train.py ===== | |
| """DiffusionBlocks training mode folded into AGILLM-4 (gated by --dblock). | |
| Block-wise EDM denoising on the real Encoder blocks, supervising AR + SAT(fixed+var) | |
| + NAT each step on ONE block, with grad-checkpointed layers and fused vocab-streaming | |
| CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4. | |
| Lazy-imports nB300 inside functions to avoid a circular import. | |
| """ | |
| import math | |
| import random | |
| import time | |
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint as _ck | |
| SD = 0.5 | |
| def _profile_active(state, args): | |
| limit = int(getattr(args, "profile_steps", 0) or 0) | |
| return limit > 0 and int(state.get("profile_n", 0)) < limit | |
| def _profile_add(state, name, seconds): | |
| if seconds is None: | |
| return | |
| prof = state.setdefault("profile_times", defaultdict(float)) | |
| prof[name] += float(seconds) | |
| def _profile_tic(enabled): | |
| if not enabled: | |
| return None | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| return time.perf_counter() | |
| def _profile_toc(state, name, start): | |
| if start is None: | |
| return | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| _profile_add(state, name, time.perf_counter() - start) | |
| def _profile_step_done(state, args): | |
| limit = int(getattr(args, "profile_steps", 0) or 0) | |
| if limit <= 0: | |
| return | |
| n_prev = int(state.get("profile_n", 0)) | |
| if n_prev >= limit: | |
| return | |
| state["profile_n"] = n_prev + 1 | |
| n = int(state["profile_n"]) | |
| log_every = max(1, int(getattr(args, "profile_log_every", 25) or 25)) | |
| if n % log_every != 0 and n != limit: | |
| return | |
| times = state.get("profile_times", {}) | |
| keys = [ | |
| "data_stream", "tensor", "setup", | |
| "ar_forward", "ar_ce", "ar_backward", | |
| "sat_forward", "sat_ce", "sat_backward", | |
| "nat_forward", "nat_ce", "nat_backward", | |
| "opt_step", "step_total", | |
| ] | |
| parts = [] | |
| for key in keys: | |
| val = float(times.get(key, 0.0)) * 1000.0 / max(1, n) | |
| if val > 0.01: | |
| parts.append(f"{key}={val:.2f}ms") | |
| print(f"[profile] n={n}/{limit} avg " + " ".join(parts), flush=True) | |
| def _cdf(x): | |
| return 0.5 * (1 + math.erf(x / math.sqrt(2))) | |
| def _ppf(p): | |
| return float(torch.erfinv(torch.tensor(2 * p - 1.0)) * math.sqrt(2)) | |
| def _block_sigmas(B, smin=0.002, smax=80.0, pm=-1.2, ps=1.2): | |
| a, b = _cdf((math.log(smin) - pm) / ps), _cdf((math.log(smax) - pm) / ps) | |
| return [float(np.exp(pm + ps * _ppf(a + (b - a) * (i / B)))) for i in range(B + 1)] | |
| def _edm_pre(s): | |
| s = s[:, None, None] | |
| return SD**2 / (s**2 + SD**2), s * SD / (s**2 + SD**2) ** 0.5, 1 / (s**2 + SD**2) ** 0.5 | |
| def _edm_w(s, wmax=5.0): | |
| return float(((s**2 + SD**2) / (s * SD) ** 2).clamp(max=wmax).mean()) | |
| _DBLOCK_ROUTER_EVENT_FEATURES = 10 | |
| _DBLOCK_ROUTER_HISTORY = 32 | |
| class _DblockLearnedRouter(nn.Module): | |
| # Transformer DBlock router conditioned on the network's running representation | |
| # plus a bounded route/outcome memory. Sequence = [CTX] + B block tokens + H | |
| # recent outcome tokens, so routing can learn from what the model is seeing now | |
| # and what the previous routing choices actually did to loss. | |
| def __init__(self, ctx_dim, d_model=64, heads=4, layers=2, feat_dim=6, n_blocks_max=64, history=_DBLOCK_ROUTER_HISTORY, event_dim=_DBLOCK_ROUTER_EVENT_FEATURES): | |
| super().__init__() | |
| d_model = max(16, int(d_model)) | |
| heads = max(1, int(heads)) | |
| if d_model % heads != 0: | |
| heads = 1 | |
| self.ctx_dim = int(ctx_dim) | |
| self.feat_dim = int(feat_dim) | |
| self.history = max(0, int(history)) | |
| self.event_dim = int(event_dim) | |
| self.block_emb = nn.Embedding(int(n_blocks_max), d_model) | |
| self.feat_proj = nn.Linear(int(feat_dim), d_model) | |
| self.ctx_proj = nn.Linear(int(ctx_dim), d_model) | |
| self.event_proj = nn.Linear(self.event_dim, d_model) | |
| self.kind_emb = nn.Embedding(3, d_model) | |
| self.event_pos = nn.Embedding(max(1, self.history), d_model) | |
| self.cls = nn.Parameter(torch.zeros(1, 1, d_model)) | |
| enc = nn.TransformerEncoderLayer( | |
| d_model=d_model, nhead=heads, dim_feedforward=max(32, d_model * 4), | |
| dropout=0.0, activation="gelu", batch_first=True, norm_first=True, | |
| ) | |
| self.encoder = nn.TransformerEncoder(enc, num_layers=max(1, int(layers))) | |
| self.ln = nn.LayerNorm(d_model) | |
| self.value = nn.Sequential( | |
| nn.LayerNorm(d_model * 2), | |
| nn.Linear(d_model * 2, d_model), | |
| nn.GELU(), | |
| nn.Linear(d_model, 1), | |
| ) | |
| nn.init.normal_(self.cls, std=0.02) | |
| def _fit_last_dim(x, dim): | |
| if x.size(-1) == dim: | |
| return x | |
| if x.size(-1) > dim: | |
| return x[..., :dim] | |
| return F.pad(x, (0, dim - x.size(-1))) | |
| def forward(self, block_ids, feats, ctx, history=None): | |
| feats = self._fit_last_dim(feats.float(), self.feat_dim) | |
| ctx = self._fit_last_dim(ctx.float(), self.ctx_dim) | |
| B = feats.size(1) | |
| bt = self.block_emb(block_ids.clamp(min=0, max=self.block_emb.num_embeddings - 1)) + self.feat_proj(feats) | |
| bt = bt + self.kind_emb(torch.ones(B, dtype=torch.long, device=feats.device)).unsqueeze(0) | |
| ctx_tok = self.cls + self.ctx_proj(ctx).unsqueeze(1) | |
| ctx_tok = ctx_tok + self.kind_emb(torch.zeros(1, dtype=torch.long, device=feats.device)).view(1, 1, -1) | |
| tokens = [ctx_tok, bt] | |
| if history is not None and self.history > 0: | |
| if not torch.is_tensor(history): | |
| history = torch.tensor(history, dtype=feats.dtype, device=feats.device) | |
| else: | |
| history = history.to(device=feats.device, dtype=feats.dtype) | |
| if history.dim() == 2: | |
| history = history.unsqueeze(0) | |
| if history.dim() == 3 and history.numel() > 0: | |
| if history.size(0) == 1 and feats.size(0) > 1: | |
| history = history.expand(feats.size(0), -1, -1) | |
| elif history.size(0) != feats.size(0): | |
| history = history[:1].expand(feats.size(0), -1, -1) | |
| if history.size(1) > self.history: | |
| history = history[:, -self.history :, :] | |
| history = self._fit_last_dim(history, self.event_dim) | |
| H = history.size(1) | |
| if H > 0: | |
| pos = torch.arange(H, dtype=torch.long, device=feats.device).clamp(max=max(0, self.history - 1)) | |
| kind = torch.full((H,), 2, dtype=torch.long, device=feats.device) | |
| ht = self.event_proj(history) + self.event_pos(pos).unsqueeze(0) + self.kind_emb(kind).unsqueeze(0) | |
| tokens.append(ht) | |
| h = self.ln(self.encoder(torch.cat(tokens, dim=1))) | |
| ctx_h = h[:, 0:1, :].expand(-1, B, -1) | |
| block_h = h[:, 1 : 1 + B, :] | |
| return self.value(torch.cat([block_h, ctx_h], dim=-1)).squeeze(-1) | |
| def _dblock_router_mode(args): | |
| return str(getattr(args, "dblock_router", "heuristic") or "heuristic").lower() | |
| def _dblock_router_enabled(args): | |
| return _dblock_router_mode(args) in {"transformer", "learned", "neural"} | |
| def _dblock_router_boot(state, args, ctx_dim=None): | |
| if not _dblock_router_enabled(args): | |
| return | |
| hidden = int(getattr(args, "dblock_router_hidden", 64) or 64) | |
| heads = int(getattr(args, "dblock_router_heads", 4) or 4) | |
| layers = int(getattr(args, "dblock_router_layers", 2) or 2) | |
| lr = float(getattr(args, "dblock_router_lr", 0.002) or 0.002) | |
| history = max(8, min(128, int(getattr(args, "dblock_router_history", _DBLOCK_ROUTER_HISTORY) or _DBLOCK_ROUTER_HISTORY))) | |
| cdim = int(ctx_dim or state.get("router_ctx_dim", 0) or 64) | |
| state["router_ctx_dim"] = cdim | |
| router = _DblockLearnedRouter(ctx_dim=cdim, d_model=hidden, heads=heads, layers=layers, history=history).to("cpu") | |
| state["router"] = router | |
| state["router_opt"] = torch.optim.AdamW(router.parameters(), lr=lr, weight_decay=1e-3) | |
| state["router_target_ema"] = None | |
| state["router_target_abs_ema"] = None | |
| state["router_train_loss"] = None | |
| state["router_last"] = None | |
| state["router_history"] = [] | |
| state["router_history_limit"] = history | |
| print( | |
| f"[dblock] learned_router=ctx_seq_transformer hidden={hidden} heads={heads} layers={layers} ctx_dim={cdim} history={history} lr={lr:g} " | |
| f"blend={float(getattr(args, 'dblock_router_blend', 0.35)):.2f} " | |
| f"ramp_steps={int(getattr(args, 'dblock_router_ramp_steps', 256) or 0)}", | |
| flush=True, | |
| ) | |
| def _dblock_router_features(state, args): | |
| B = int(state["B"]) | |
| step = int(state.get("step", 0)) | |
| counts = list(state.get("counts", [0 for _ in range(B)])) | |
| if len(counts) != B: | |
| counts = [0 for _ in range(B)] | |
| emas = list(state.get("loss_ema", [None for _ in range(B)])) | |
| if len(emas) != B: | |
| emas = [None for _ in range(B)] | |
| last_seen = list(state.get("last_seen", [-1 for _ in range(B)])) | |
| if len(last_seen) != B: | |
| last_seen = [-1 for _ in range(B)] | |
| bsig = list(state.get("bsig", _block_sigmas(B))) | |
| max_count = max(1, max(counts) if counts else 1) | |
| known = [float(x) for x in emas if x is not None and math.isfinite(float(x))] | |
| center = sum(known) / len(known) if known else 0.0 | |
| scale = (sum((x - center) ** 2 for x in known) / len(known)) ** 0.5 if len(known) > 1 else max(1.0, abs(center) * 0.05) | |
| scale = max(1e-3, scale) | |
| stale = [step - last_seen[i] if last_seen[i] >= 0 else step + 1 for i in range(B)] | |
| max_stale = int(getattr(args, "dblock_max_stale_steps", 64) or 0) | |
| stale_denom = float(max(1, max_stale if max_stale > 0 else max(stale) if stale else 1)) | |
| logs = [math.log(max(1e-9, float(x))) for x in bsig] | |
| log_min = min(logs) if logs else 0.0 | |
| log_span = max(1e-6, (max(logs) - log_min) if logs else 1.0) | |
| feats = [] | |
| for i in range(B): | |
| ema = emas[i] | |
| known_flag = 1.0 if ema is not None and math.isfinite(float(ema)) else 0.0 | |
| loss_z = 0.0 if not known_flag else max(-5.0, min(5.0, (float(ema) - center) / scale)) | |
| lo = logs[min(i, len(logs) - 1)] if logs else 0.0 | |
| hi = logs[min(i + 1, len(logs) - 1)] if logs else lo | |
| sig_mid = ((0.5 * (lo + hi)) - log_min) / log_span | |
| feats.append([ | |
| loss_z, known_flag, float(counts[i]) / float(max_count), | |
| max(0.0, float(max_count - counts[i]) / float(max_count)), | |
| min(1.0, max(0.0, float(stale[i]) / stale_denom)), float(sig_mid), | |
| ]) | |
| block_ids = torch.arange(B, dtype=torch.long).unsqueeze(0) | |
| ft = torch.tensor([feats], dtype=torch.float32) | |
| cdim = int(state.get("router_ctx_dim", 0) or 0) | |
| ctx = state.get("router_ctx") | |
| if torch.is_tensor(ctx) and cdim > 0 and ctx.numel() == cdim: | |
| cv = ctx.detach().reshape(1, cdim).float() | |
| else: | |
| cv = torch.zeros(1, max(1, cdim)) | |
| return block_ids, ft, cv | |
| def _dblock_router_clip(x, lo=-5.0, hi=5.0): | |
| try: | |
| x = float(x) | |
| except Exception: | |
| return 0.0 | |
| if not math.isfinite(x): | |
| return 0.0 | |
| return max(lo, min(hi, x)) | |
| def _dblock_router_history_features(state, args): | |
| limit = int(state.get("router_history_limit", getattr(args, "dblock_router_history", _DBLOCK_ROUTER_HISTORY)) or 0) | |
| limit = max(0, min(128, limit)) | |
| if limit <= 0: | |
| return torch.zeros((1, 0, _DBLOCK_ROUTER_EVENT_FEATURES), dtype=torch.float32) | |
| hist = list(state.get("router_history", []))[-limit:] | |
| if not hist: | |
| return torch.zeros((1, 0, _DBLOCK_ROUTER_EVENT_FEATURES), dtype=torch.float32) | |
| B = int(state["B"]) | |
| step = int(state.get("step", 0)) | |
| losses = [] | |
| for rec in hist: | |
| try: | |
| loss = float(rec.get("loss", 0.0)) | |
| except Exception: | |
| loss = 0.0 | |
| if math.isfinite(loss): | |
| losses.append(loss) | |
| center = sum(losses) / len(losses) if losses else 0.0 | |
| scale = (sum((x - center) ** 2 for x in losses) / len(losses)) ** 0.5 if len(losses) > 1 else max(1.0, abs(center) * 0.05) | |
| scale = max(1e-3, scale) | |
| rows = [] | |
| for rec in hist: | |
| rec_step = int(rec.get("step", -1)) | |
| block = max(0, min(B - 1, int(rec.get("block", 0)))) | |
| age = max(0, step - rec_step) | |
| try: | |
| rec_loss = float(rec.get("loss", center)) | |
| except Exception: | |
| rec_loss = center | |
| loss = _dblock_router_clip((rec_loss - center) / scale) | |
| rows.append([ | |
| float(block) / float(max(1, B - 1)), | |
| _dblock_router_clip(rec.get("target", 0.0)), | |
| loss, | |
| max(0.0, min(1.0, float(rec.get("count_norm", 0.0)))), | |
| max(0.0, min(1.0, float(rec.get("stale_norm", 0.0)))), | |
| min(1.0, math.log1p(age) / math.log1p(max(2, limit))), | |
| min(1.0, math.log1p(max(0, rec_step)) / math.log1p(10000.0)), | |
| 1.0 if float(rec.get("router_choice", 0.0)) > 0.0 else 0.0, | |
| max(0.0, min(1.0, float(rec.get("blend", 0.0)))), | |
| 1.0, | |
| ]) | |
| return torch.tensor([rows], dtype=torch.float32) | |
| def _dblock_router_append_history(state, args, bi, loss_float, target_val): | |
| limit = int(state.get("router_history_limit", getattr(args, "dblock_router_history", _DBLOCK_ROUTER_HISTORY)) or _DBLOCK_ROUTER_HISTORY) | |
| limit = max(0, min(128, limit)) | |
| if limit <= 0: | |
| return | |
| B = int(state["B"]) | |
| step = int(state.get("step", 0)) | |
| counts = list(state.get("counts", [0 for _ in range(B)])) | |
| if len(counts) != B: | |
| counts = [0 for _ in range(B)] | |
| last_seen = list(state.get("last_seen", [-1 for _ in range(B)])) | |
| if len(last_seen) != B: | |
| last_seen = [-1 for _ in range(B)] | |
| max_count = max(1, max(counts) if counts else 1) | |
| stale = step - last_seen[int(bi)] if 0 <= int(bi) < len(last_seen) and last_seen[int(bi)] >= 0 else step + 1 | |
| max_stale = int(getattr(args, "dblock_max_stale_steps", 64) or 0) | |
| stale_denom = float(max(1, max_stale if max_stale > 0 else stale)) | |
| route = state.get("router_last") | |
| router_choice = 0.0 | |
| blend = 0.0 | |
| if isinstance(route, dict): | |
| router_choice = 1.0 if int(route.get("choice", -1)) == int(bi) else 0.0 | |
| blend = float(route.get("blend", 0.0)) | |
| hist = state.setdefault("router_history", []) | |
| hist.append({ | |
| "step": int(step), | |
| "block": int(bi), | |
| "loss": float(loss_float), | |
| "target": float(target_val), | |
| "count_norm": float(counts[int(bi)]) / float(max_count) if 0 <= int(bi) < len(counts) else 0.0, | |
| "stale_norm": min(1.0, max(0.0, float(stale) / stale_denom)), | |
| "router_choice": router_choice, | |
| "blend": blend, | |
| }) | |
| if len(hist) > limit: | |
| del hist[:-limit] | |
| def _dblock_router_norm(xs): | |
| vals = [0.0 if not math.isfinite(float(x)) else float(x) for x in xs] | |
| if not vals: | |
| return vals | |
| mean = sum(vals) / len(vals) | |
| scale = max(1e-6, (sum((x - mean) ** 2 for x in vals) / len(vals)) ** 0.5) | |
| return [(x - mean) / scale for x in vals] | |
| def _dblock_router_choose(state, args, heuristic_scores): | |
| state["router_last"] = None | |
| if not _dblock_router_enabled(args): | |
| return None | |
| router = state.get("router") | |
| if router is None: | |
| return None | |
| B = int(state["B"]) | |
| step = int(state.get("step", 0)) | |
| warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2))) | |
| ramp_steps = int(getattr(args, "dblock_router_ramp_steps", 256) or 0) | |
| blend_base = max(0.0, min(1.0, float(getattr(args, "dblock_router_blend", 0.35) or 0.0))) | |
| if step < warmup or blend_base <= 0.0: | |
| return None | |
| ramp = 1.0 if ramp_steps <= 0 else min(1.0, max(0.0, float(step - warmup) / float(ramp_steps))) | |
| blend = blend_base * ramp | |
| if blend <= 1e-6: | |
| return None | |
| history_features = _dblock_router_history_features(state, args) | |
| with torch.no_grad(): | |
| router.eval() | |
| pred = router(*_dblock_router_features(state, args), history=history_features)[0].detach().cpu().tolist() | |
| h = _dblock_router_norm(heuristic_scores) | |
| q = _dblock_router_norm(pred) | |
| if len(h) != B or len(q) != B: | |
| return None | |
| counts = state.get("counts", [0 for _ in range(B)]) | |
| combined = [(1.0 - blend) * h[i] + blend * q[i] for i in range(B)] | |
| choice = max(range(B), key=lambda i: (combined[i], -counts[i], -i)) | |
| state["router_last"] = { | |
| "mode": "ctx_seq_transformer", | |
| "choice": int(choice), | |
| "blend": float(blend), | |
| "history": int(history_features.size(1)), | |
| "pred": [float(x) for x in pred], | |
| } | |
| return choice | |
| def _dblock_router_update(state, args, bi, loss_value): | |
| if not _dblock_router_enabled(args): | |
| return | |
| router, opt = state.get("router"), state.get("router_opt") | |
| if router is None or opt is None: | |
| return | |
| try: | |
| loss_float = float(loss_value) | |
| except Exception: | |
| return | |
| if not math.isfinite(loss_float): | |
| return | |
| baseline = state.get("router_target_ema") | |
| scale = state.get("router_target_abs_ema") | |
| if baseline is None or not math.isfinite(float(baseline)): | |
| baseline = loss_float | |
| if scale is None or not math.isfinite(float(scale)) or float(scale) < 1e-3: | |
| scale = max(1.0, abs(loss_float) * 0.05) | |
| target_val = max(-5.0, min(5.0, (loss_float - float(baseline)) / max(1e-3, float(scale)))) | |
| router.train() | |
| pred = router(*_dblock_router_features(state, args), history=_dblock_router_history_features(state, args))[0, int(bi)] | |
| fit_loss = F.smooth_l1_loss(pred, pred.detach().new_tensor(target_val)) | |
| opt.zero_grad(set_to_none=True) | |
| fit_loss.backward() | |
| nn.utils.clip_grad_norm_(router.parameters(), 1.0) | |
| opt.step() | |
| diff = abs(loss_float - float(baseline)) | |
| state["router_target_ema"] = 0.98 * float(baseline) + 0.02 * loss_float | |
| state["router_target_abs_ema"] = 0.98 * float(scale) + 0.02 * max(1e-3, diff) | |
| state["router_train_loss"] = float(fit_loss.detach().cpu()) | |
| _dblock_router_append_history(state, args, bi, loss_float, target_val) | |
| def _dblock_get_candidates(L): | |
| c = [] | |
| # 1. Uniform candidates for b in [2, 3, 4, 6] | |
| for b in [2, 3, 4, 6]: | |
| per = max(1, L // b) | |
| asg = [list(range(i * per, (i + 1) * per)) for i in range(b)] | |
| asg[-1] = list(range((b - 1) * per, L)) | |
| c.append((b, asg, f"Uniform-{b}")) | |
| # 2. Non-uniform candidates for B=3 | |
| # Middle-heavy (e.g. 25%, 50%, 25%) | |
| m_h = [max(1, L // 4), max(1, L // 2)] | |
| m_h.append(L - sum(m_h)) | |
| asg = [] | |
| curr = 0 | |
| for size in m_h: | |
| asg.append(list(range(curr, curr + size))) | |
| curr += size | |
| c.append((3, asg, "Middle-Heavy-3")) | |
| # End-heavy (e.g. 20%, 35%, 45%) | |
| e_h = [max(1, int(L * 0.20)), max(1, int(L * 0.35))] | |
| e_h.append(L - sum(e_h)) | |
| asg = [] | |
| curr = 0 | |
| for size in e_h: | |
| asg.append(list(range(curr, curr + size))) | |
| curr += size | |
| c.append((3, asg, "End-Heavy-3")) | |
| # Start-heavy (e.g. 45%, 35%, 20%) | |
| s_h = [max(1, int(L * 0.45)), max(1, int(L * 0.35))] | |
| s_h.append(L - sum(s_h)) | |
| asg = [] | |
| curr = 0 | |
| for size in s_h: | |
| asg.append(list(range(curr, curr + size))) | |
| curr += size | |
| c.append((3, asg, "Start-Heavy-3")) | |
| # 3. Non-uniform candidates for B=4 | |
| # Middle-heavy (e.g. 20%, 30%, 30%, 20%) | |
| m_h4 = [max(1, int(L * 0.20)), max(1, int(L * 0.30)), max(1, int(L * 0.30))] | |
| m_h4.append(L - sum(m_h4)) | |
| asg = [] | |
| curr = 0 | |
| for size in m_h4: | |
| asg.append(list(range(curr, curr + size))) | |
| curr += size | |
| c.append((4, asg, "Middle-Heavy-4")) | |
| # End-heavy (e.g. 15%, 25%, 30%, 30%) | |
| e_h4 = [max(1, int(L * 0.15)), max(1, int(L * 0.25)), max(1, int(L * 0.30))] | |
| e_h4.append(L - sum(e_h4)) | |
| asg = [] | |
| curr = 0 | |
| for size in e_h4: | |
| asg.append(list(range(curr, curr + size))) | |
| curr += size | |
| c.append((4, asg, "End-Heavy-4")) | |
| return c | |
| def _dblock_init(core, args): | |
| L = len(core.blocks) | |
| auto_search = getattr(args, "auto_dblock_search", False) | |
| if auto_search: | |
| candidates = _dblock_get_candidates(L) | |
| print(f"[dblock] Auto Search enabled with {len(candidates)} candidates.") | |
| B, asg, name = candidates[0] | |
| state = { | |
| "auto_search": True, | |
| "candidates": candidates, | |
| "candidate_idx": 0, | |
| "search_step": 0, | |
| "search_interval": 20, | |
| "scores": [], | |
| } | |
| else: | |
| B = int(getattr(args, "dblock_blocks", 4)) | |
| sp = max(1, L // B) | |
| asg = [list(range(i * sp, (i + 1) * sp)) for i in range(B)] | |
| asg[-1] = list(range((B - 1) * sp, L)) | |
| state = {"auto_search": False} | |
| bsig = _block_sigmas(B) | |
| schedule = getattr(args, "dblock_schedule", "loss_balanced") | |
| print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}") | |
| print(f"[dblock] schedule={schedule} sigma boundaries: {[round(x, 3) for x in bsig]}") | |
| state.update({ | |
| "B": B, | |
| "assign": asg, | |
| "bsig": bsig, | |
| "step": 0, | |
| "counts": [0 for _ in range(B)], | |
| "loss_ema": [None for _ in range(B)], | |
| "last_seen": [-1 for _ in range(B)], | |
| }) | |
| _dblock_router_boot(state, args, ctx_dim=int(getattr(core.emb, "embedding_dim", 0)) or None) | |
| return state | |
| def _choose_block(state, args): | |
| if not state.get("auto_search", False) and state.get("step", 0) % 100 == 0: | |
| try: | |
| cfg = get_hot_config() | |
| if "dblock_blocks" in cfg: | |
| new_B = int(cfg["dblock_blocks"]) | |
| if new_B != state.get("B"): | |
| L = sum(len(x) for x in state["assign"]) if "assign" in state else 28 | |
| new_sp = max(1, L // new_B) | |
| new_asg = [list(range(i * new_sp, (i + 1) * new_sp)) for i in range(new_B)] | |
| new_asg[-1] = list(range((new_B - 1) * new_sp, L)) | |
| print(f"[dblock] Dynamically adjusting block configuration from hot_config: B={state['B']} -> {new_B}, assign={new_asg}", flush=True) | |
| state["B"] = new_B | |
| state["assign"] = new_asg | |
| state["bsig"] = _block_sigmas(new_B) | |
| state["counts"] = [0] * new_B | |
| state["loss_ema"] = [None] * new_B | |
| state["last_seen"] = [-1] * new_B | |
| except Exception as e: | |
| print(f"[dblock] Error reloading hot_config in _choose_block: {e}", flush=True) | |
| if state.get("auto_search", False) and state["candidate_idx"] < len(state["candidates"]): | |
| state["search_step"] += 1 | |
| if "search_start_time" not in state: | |
| state["search_start_time"] = time.perf_counter() | |
| state["search_tokens"] = 0 | |
| if state["search_step"] >= state["search_interval"]: | |
| valid_emas = [e for e in state["loss_ema"] if e is not None] | |
| avg_loss = sum(valid_emas) / max(1, len(valid_emas)) if valid_emas else float('inf') | |
| elapsed = time.perf_counter() - state["search_start_time"] | |
| tokens = state.get("search_tokens", 0) | |
| tokps = tokens / max(1e-9, elapsed) | |
| cand = state["candidates"][state["candidate_idx"]] | |
| cand_name = cand[2] if len(cand) > 2 else f"Candidate-{state['candidate_idx']}" | |
| state["scores"].append({ | |
| "idx": state["candidate_idx"], | |
| "B": state["B"], | |
| "assign": state["assign"], | |
| "name": cand_name, | |
| "loss": avg_loss, | |
| "tokps": tokps | |
| }) | |
| print(f"[dblock] Candidate {state['candidate_idx']} ({cand_name}) complete: loss={avg_loss:.4f} speed={tokps:.1f} tok/s", flush=True) | |
| state["candidate_idx"] += 1 | |
| state["search_step"] = 0 | |
| if "search_start_time" in state: | |
| del state["search_start_time"] | |
| state["search_tokens"] = 0 | |
| if state["candidate_idx"] < len(state["candidates"]): | |
| B, asg, cand_name = state["candidates"][state["candidate_idx"]] | |
| state["B"] = B | |
| state["assign"] = asg | |
| state["bsig"] = _block_sigmas(B) | |
| state["counts"] = [0] * B | |
| state["loss_ema"] = [None] * B | |
| state["last_seen"] = [-1] * B | |
| print(f"[dblock] Switched to candidate {state['candidate_idx']} ({cand_name}): {B} blocks {asg}", flush=True) | |
| else: | |
| # Select the candidate with highest speed/loss utility | |
| best_cand = None | |
| best_utility = -1.0 | |
| for score_entry in state["scores"]: | |
| loss = score_entry["loss"] | |
| tokps = score_entry["tokps"] | |
| utility = tokps / max(1e-3, loss) | |
| score_entry["utility"] = utility | |
| if utility > best_utility: | |
| best_utility = utility | |
| best_cand = score_entry | |
| B = best_cand["B"] | |
| asg = best_cand["assign"] | |
| state["B"] = B | |
| state["assign"] = asg | |
| state["bsig"] = _block_sigmas(B) | |
| state["auto_search"] = False | |
| print(f"[dblock] Search complete. Locked in best candidate {best_cand['name']} (Utility={best_utility:.2f}, Loss={best_cand['loss']:.4f}, Speed={best_cand['tokps']:.1f} tok/s): {B} blocks {asg}", flush=True) | |
| B = state["B"] | |
| schedule = str(getattr(args, "dblock_schedule", "loss_balanced") or "loss_balanced").lower() | |
| step = int(state.get("step", 0)) | |
| counts = state.setdefault("counts", [0 for _ in range(B)]) | |
| if len(counts) != B: | |
| counts[:] = [0 for _ in range(B)] | |
| emas = state.setdefault("loss_ema", [None for _ in range(B)]) | |
| if len(emas) != B: | |
| emas[:] = [None for _ in range(B)] | |
| last_seen = state.setdefault("last_seen", [-1 for _ in range(B)]) | |
| if len(last_seen) != B: | |
| last_seen[:] = [-1 for _ in range(B)] | |
| state["router_last"] = None | |
| if schedule == "random": | |
| return random.randrange(B) | |
| if schedule == "roundrobin": | |
| return step % B | |
| explore = max(0.0, min(1.0, float(getattr(args, "dblock_explore", 0.05)))) | |
| warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2))) | |
| def least_trained(): | |
| return min(range(B), key=lambda i: (counts[i], last_seen[i], i)) | |
| if step < warmup or any(c == 0 for c in counts): | |
| return least_trained() | |
| max_stale = int(getattr(args, "dblock_max_stale_steps", 64) or 0) | |
| stale = [step - last_seen[i] if last_seen[i] >= 0 else step + 1 for i in range(B)] | |
| if max_stale > 0 and max(stale) >= max_stale: | |
| return max(range(B), key=lambda i: (stale[i], -counts[i], -i)) | |
| max_count = max(counts) if counts else 0 | |
| min_count = min(counts) if counts else 0 | |
| max_skew = float(getattr(args, "dblock_max_count_skew", 1.35) or 0.0) | |
| if max_skew > 1.0 and min_count > 0 and (max_count / max(1, min_count)) > max_skew: | |
| return least_trained() | |
| if explore > 0.0 and random.random() < explore: | |
| return least_trained() | |
| stale_bonus = float(getattr(args, "dblock_stale_bonus", 0.35) or 0.0) | |
| undertrain_bonus = float(getattr(args, "dblock_undertrain_bonus", 0.25) or 0.0) | |
| stale_denom = float(max(1, max_stale if max_stale > 0 else max(stale) if stale else 1)) | |
| count_denom = float(max(1, max_count)) | |
| def score(i): | |
| loss_score = -1.0 if emas[i] is None else float(emas[i]) | |
| stale_score = stale_bonus * min(1.0, max(0.0, stale[i] / stale_denom)) | |
| undertrain_score = undertrain_bonus * max(0.0, (max_count - counts[i]) / count_denom) | |
| return (loss_score + stale_score + undertrain_score, -counts[i], stale[i], -i) | |
| heuristic_scores = [float(score(i)[0]) for i in range(B)] | |
| heuristic_choice = max(range(B), key=score) | |
| learned_choice = _dblock_router_choose(state, args, heuristic_scores) | |
| return heuristic_choice if learned_choice is None else learned_choice | |
| def _sample_sigma(ids, lo, hi, args, state): | |
| cur_step = int(state.get("step", 0)) | |
| curriculum = int(getattr(args, "dblock_sigma_curriculum_steps", 0)) | |
| if curriculum > 0: | |
| frac = min(1.0, max(0.05, (cur_step + 1) / float(curriculum))) | |
| hi = lo * ((hi / max(lo, 1e-8)) ** frac) | |
| sig_np = np.exp( | |
| np.random.uniform( | |
| math.log(max(lo, 1e-4)), | |
| math.log(max(hi, lo + 1e-4)), | |
| ids.size(0), | |
| ).astype("float32") | |
| ) | |
| return torch.from_numpy(sig_np).to(ids.device) | |
| def _maybe_log( | |
| state, | |
| args, | |
| bi, | |
| layers, | |
| ar_val, | |
| sat_val, | |
| nat_val, | |
| total_val, | |
| peak_alloc, | |
| peak_reserved, | |
| objective=None, | |
| raw_avg=None, | |
| raw_total=None, | |
| edm_weight=None, | |
| ): | |
| log_every = int(getattr(args, "dblock_log_every", 50)) | |
| step = int(state.get("step", 0)) | |
| if log_every <= 0 or step % log_every != 0: | |
| return | |
| counts_list = state.get("counts", []) | |
| last_seen = state.get("last_seen", [-1 for _ in counts_list]) | |
| counts = ",".join(str(x) for x in counts_list) | |
| emas = ",".join("nan" if x is None else f"{x:.2f}" for x in state.get("loss_ema", [])) | |
| stale = ",".join(str(max(0, step - int(last_seen[i]))) for i in range(min(len(counts_list), len(last_seen)))) | |
| mem = "" | |
| if peak_alloc is not None: | |
| mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB" | |
| display = float(raw_avg) if raw_avg is not None and math.isfinite(float(raw_avg)) else float(total_val) | |
| raw_part = "" | |
| if raw_total is not None: | |
| raw_part += f" raw_sum={float(raw_total):.3f}" | |
| if edm_weight is not None: | |
| raw_part += f" edm_w={float(edm_weight):.3f}" | |
| route = state.get("router_last") | |
| if isinstance(route, dict): | |
| pred = ",".join(f"{float(x):.2f}" for x in route.get("pred", [])) | |
| hist = route.get("history") | |
| hist_part = "" if hist is None else f" hist={int(hist)}" | |
| raw_part += f" router={route.get('mode', 'none')} blend={float(route.get('blend', 0.0)):.2f}{hist_part} pred=[{pred}]" | |
| rloss = state.get("router_train_loss") | |
| if rloss is not None: | |
| raw_part += f" router_fit={float(rloss):.3f}" | |
| print( | |
| f"[dblock] step={step} block={bi} obj={objective or 'mixed'} layers={layers} " | |
| f"loss={display:.3f} weighted={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f}" | |
| f"{raw_part} counts=[{counts}] ema=[{emas}] stale=[{stale}]{mem}", | |
| flush=True, | |
| ) | |
| def _update_stats(state, bi, loss_value, args=None): | |
| if args is not None: | |
| _dblock_router_update(state, args, bi, loss_value) | |
| B = state["B"] | |
| counts = state.setdefault("counts", [0 for _ in range(B)]) | |
| emas = state.setdefault("loss_ema", [None for _ in range(B)]) | |
| last_seen = state.setdefault("last_seen", [-1 for _ in range(B)]) | |
| if len(last_seen) != B: | |
| last_seen[:] = [-1 for _ in range(B)] | |
| counts[bi] += 1 | |
| last_seen[bi] = int(state.get("step", 0)) | |
| prev = emas[bi] | |
| beta = 0.96 | |
| emas[bi] = float(loss_value) if prev is None else beta * float(prev) + (1.0 - beta) * float(loss_value) | |
| state["step"] = int(state.get("step", 0)) + 1 | |
| def _activation_offload_enabled(args): | |
| return bool(getattr(args, "dblock_activation_offload", False)) and torch.cuda.is_available() | |
| def _activation_offload_hooks(args): | |
| min_bytes = int(float(getattr(args, "dblock_activation_offload_min_mb", 1.0) or 1.0) * 1024 * 1024) | |
| def pack(t): | |
| if not torch.is_tensor(t) or not t.is_cuda or not t.is_floating_point() or t.numel() * t.element_size() < min_bytes: | |
| return t | |
| return ("cpu_offload", t.device, t.detach().to("cpu", non_blocking=True)) | |
| def unpack(x): | |
| if isinstance(x, tuple) and len(x) == 3 and x[0] == "cpu_offload": | |
| _, dev, cpu_t = x | |
| return cpu_t.to(dev, non_blocking=True) | |
| return x | |
| return torch.autograd.graph.saved_tensors_hooks(pack, unpack) | |
| def _dblock_sublayer_base_mode(args): | |
| mode = str(getattr(args, "dblock_sublayer_mode", "off") or "off").strip().lower().replace("-", "_") | |
| if mode in {"none", "disabled"}: | |
| return "off" | |
| return mode | |
| def _dblock_sublayer_mode_for_layer(args, state, block_idx, layer_pos): | |
| mode = _dblock_sublayer_base_mode(args) | |
| if mode == "split_alt": | |
| step = int((state or {}).get("step", 0)) | |
| return "attn_only" if ((step + int(block_idx) + int(layer_pos)) % 2 == 0) else "ffn_only" | |
| if mode == "cycle": | |
| step = int((state or {}).get("step", 0)) | |
| return ("full", "ffn_only", "attn_only")[(step + int(block_idx) + int(layer_pos)) % 3] | |
| return mode | |
| def _run_block_forward(block, x, mask, sublayer_mode="off"): | |
| mode = str(sublayer_mode or "off").strip().lower().replace("-", "_") | |
| if mode in {"off", "full"}: | |
| return block(x, mask) | |
| if mode == "attn_only": | |
| n = x.size(1) | |
| return x + block.mha(block.ln1(x), mask, rel_bias_tokens=n) | |
| if mode == "ffn_only": | |
| return x + block.ff(block.ln2(x)) | |
| raise ValueError(f"unknown DBlock sublayer mode: {sublayer_mode}") | |
| def _run_block(block, x, mask, use_checkpoint, args=None, sublayer_mode="off"): | |
| if use_checkpoint: | |
| return _ck.checkpoint(lambda y, block=block, mode=sublayer_mode: _run_block_forward(block, y, mask, mode), x, use_reentrant=False) | |
| if args is not None and _activation_offload_enabled(args): | |
| with _activation_offload_hooks(args): | |
| return _run_block_forward(block, x, mask, sublayer_mode) | |
| return _run_block_forward(block, x, mask, sublayer_mode) | |
| def _dblock_checkpoint_this_layer(args, base_enabled, layer_pos, layer_count=None): | |
| if not base_enabled: | |
| return False | |
| pos = int(layer_pos) | |
| count = int(layer_count or 0) | |
| skip_tail = max(0, int(getattr(args, "dblock_checkpoint_skip_tail", 0) or 0)) | |
| if skip_tail > 0 and count > 0 and pos >= max(0, count - skip_tail): | |
| return False | |
| stride = int(getattr(args, "dblock_checkpoint_stride", 1) or 1) | |
| if stride <= 0: | |
| return False | |
| if stride == 1: | |
| return True | |
| return (pos % stride) == 0 | |
| def _sample_token_loss_inputs(hidden, targets, max_tokens): | |
| max_tokens = int(max_tokens or 0) | |
| if max_tokens <= 0: | |
| return hidden.contiguous(), targets.contiguous(), int(targets.numel()), int(targets.numel()) | |
| flat_targets = targets.reshape(-1) | |
| total = int(flat_targets.numel()) | |
| if total <= max_tokens: | |
| return hidden.contiguous(), targets.contiguous(), total, total | |
| # With-replacement sampling avoids building a full randperm each step; the sampled | |
| # mean remains an unbiased estimator of the dense token CE mean. | |
| idx = torch.randint(total, (max_tokens,), device=targets.device) | |
| flat_hidden = hidden.reshape(total, hidden.size(-1)) | |
| return flat_hidden.index_select(0, idx).contiguous(), flat_targets.index_select(0, idx).contiguous(), int(max_tokens), total | |
| def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic): | |
| mode = str(getattr(args, "dblock_objective_mode", "periodic") or "periodic").lower() | |
| if mode != "stochastic": | |
| return ar_weight > 0.0, sat_weight > 0.0 and do_sat_periodic, nat_weight > 0.0 and do_nat_periodic, "periodic" | |
| choices = [] | |
| probs = [] | |
| if ar_weight > 0.0: | |
| choices.append("ar") | |
| probs.append(max(0.0, float(getattr(args, "dblock_ar_prob", 0.80)))) | |
| if sat_weight > 0.0 and not getattr(args, "ar_only", False): | |
| choices.append("sat") | |
| probs.append(max(0.0, float(getattr(args, "dblock_sat_prob", 0.10)))) | |
| if nat_weight > 0.0 and not getattr(args, "ar_only", False): | |
| choices.append("nat") | |
| probs.append(max(0.0, float(getattr(args, "dblock_nat_prob", 0.10)))) | |
| if not choices: | |
| return False, False, False, "none" | |
| total = sum(probs) | |
| if total <= 0.0: | |
| probs = [1.0 / len(choices) for _ in choices] | |
| else: | |
| probs = [p / total for p in probs] | |
| picked = random.choices(choices, weights=probs, k=1)[0] | |
| return picked == "ar", picked == "sat", picked == "nat", picked | |
| def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state): | |
| import nB300_agillm4 as M | |
| if state is not None and state.get("auto_search", False): | |
| state["search_tokens"] = state.get("search_tokens", 0) + ids.numel() | |
| prof = _profile_active(state, args) | |
| _step_t = _profile_tic(prof) | |
| if torch.cuda.is_available(): | |
| torch.cuda.reset_peak_memory_stats() | |
| _setup_t = _profile_tic(prof) | |
| B = state["B"] | |
| asg = state["assign"] | |
| bs = state["bsig"] | |
| T = ids.size(1) | |
| use_layer_checkpoint = bool(getattr(args, "grad_checkpoint", False)) | |
| if _dblock_router_enabled(args): | |
| with torch.no_grad(): | |
| _rc_emb = core.emb(ids) | |
| state["router_ctx"] = _rc_emb.mean(dim=(0, 1)).detach().float().to("cpu") | |
| del _rc_emb | |
| bi = _choose_block(state, args) | |
| lo, hi = sorted([bs[bi], bs[bi + 1]]) | |
| layers = asg[bi] | |
| sig = _sample_sigma(ids, lo, hi, args, state) | |
| cs, co, ci = _edm_pre(sig) | |
| w = _edm_w(sig, float(getattr(args, "dblock_edm_wmax", 5.0))) | |
| SATB = M.SAT_BLOCK | |
| ar_weight = float(getattr(args, "dblock_ar_weight", 1.0)) | |
| sat_weight = float(getattr(args, "dblock_sat_weight", 1.0)) | |
| nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0)) | |
| do_sat_periodic = (not getattr(args, "ar_only", False)) and ( | |
| int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0) | |
| ) | |
| do_nat_periodic = ( | |
| nat_h is not None | |
| and (not getattr(args, "ar_only", False)) | |
| and int(getattr(args, "nat_every", 1)) > 0 | |
| and ( | |
| int(getattr(args, "nat_every", 1)) <= 1 | |
| or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0) | |
| ) | |
| ) | |
| run_ar, run_sat, run_nat, objective = _choose_objectives( | |
| state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic | |
| ) | |
| _profile_toc(state, "setup", _setup_t) | |
| ar_val = 0.0 | |
| sat_val = 0.0 | |
| nat_val = 0.0 | |
| ar_raw_val = 0.0 | |
| sat_raw_val = 0.0 | |
| nat_raw_val = 0.0 | |
| if run_ar: | |
| causal = M.causal_mask(T, structured=M.use_structured_masks(args)) | |
| _t = _profile_tic(prof) | |
| with M.amp(args.amp): | |
| emb = core.emb(ids) | |
| zt = emb + sig[:, None, None] * torch.randn_like(emb) | |
| h = ci * zt | |
| for lpos, li in enumerate(layers): | |
| mode = _dblock_sublayer_mode_for_layer(args, state, bi, lpos) | |
| h = _run_block(core.blocks[li], h, causal, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args, mode) | |
| Dn = core.ln(cs * zt + co * h) | |
| _profile_toc(state, "ar_forward", _t) | |
| _t = _profile_tic(prof) | |
| ar_hidden, ar_targets, ar_used, ar_total = _sample_token_loss_inputs( | |
| Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0)) | |
| ) | |
| ar_raw = fused_ce(ar_hidden, ar_h.proj.weight, ar_targets) | |
| ar_raw_val = float(ar_raw.detach()) | |
| ar = ar_weight * w * ar_raw | |
| ar_val = float(ar.detach()) | |
| _profile_toc(state, "ar_ce", _t) | |
| _t = _profile_tic(prof) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| ar = ar + _aux.to(ar.dtype) | |
| scaler.scale(ar).backward() | |
| _profile_toc(state, "ar_backward", _t) | |
| del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar_raw, ar, ar_used, ar_total | |
| if run_sat: | |
| smask = M.sat_mask(T, structured=M.use_structured_masks(args)) | |
| _t = _profile_tic(prof) | |
| with M.amp(args.amp): | |
| emb2 = core.emb(ids) | |
| zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2) | |
| h2 = ci * zt2 | |
| for lpos, li in enumerate(layers): | |
| mode = _dblock_sublayer_mode_for_layer(args, state, bi, lpos) | |
| h2 = _run_block(core.blocks[li], h2, smask, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args, mode) | |
| Ds = core.ln(cs * zt2 + co * h2) | |
| _profile_toc(state, "sat_forward", _t) | |
| _t = _profile_tic(prof) | |
| # SAT decode uses the latest SAT_BLOCK hidden states to emit the next | |
| # SAT_BLOCK tokens. Train that contract densely across the context. | |
| sat_ctx = Ds[:, :-SATB] | |
| sat_tgt = ids[:, SATB:] | |
| if sat_ctx.size(1) == 0 or sat_ctx.size(1) != sat_tgt.size(1): | |
| sat_ctx = Ds[:, :-1] | |
| sat_tgt = ids[:, 1:] | |
| sat_hidden, sat_targets, sat_used, sat_total = _sample_token_loss_inputs( | |
| sat_ctx, sat_tgt, int(getattr(args, "dblock_sat_loss_tokens", 0)) | |
| ) | |
| sat_gate_ctx = sat_ctx[:, ::SATB] | |
| with M.amp(args.amp): | |
| satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets) | |
| satv = ( | |
| M.EMIT_LAMBDA | |
| * F.cross_entropy( | |
| sat_h.gate(sat_gate_ctx.reshape(-1, sat_gate_ctx.size(-1)).float()), | |
| torch.ones(sat_gate_ctx.numel() // sat_gate_ctx.size(-1), dtype=torch.long, device=ids.device), | |
| ) | |
| if sat_h.gate is not None and sat_gate_ctx.size(1) > 0 | |
| else 0.0 | |
| ) | |
| sat_raw = satf + satv | |
| sat_raw_val = float(sat_raw.detach()) | |
| sat = sat_weight * w * sat_raw | |
| _profile_toc(state, "sat_ce", _t) | |
| sat_val = float(sat.detach()) | |
| _t = _profile_tic(prof) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| sat = sat + _aux.to(sat.dtype) | |
| scaler.scale(sat).backward() | |
| _profile_toc(state, "sat_backward", _t) | |
| del smask, emb2, zt2, h2, Ds, sat_hidden, sat_targets, sat_gate_ctx, satf, satv, sat_raw, sat | |
| if run_nat: | |
| ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95) | |
| nat_mode = str(getattr(args, "dblock_nat_embed_noise_mode", "off") or "off").strip().lower() | |
| nat_noise_scale = max(0.0, float(getattr(args, "dblock_nat_embed_noise_scale", 1.0) or 1.0)) | |
| nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0))) | |
| _t = _profile_tic(prof) | |
| with M.amp(args.amp): | |
| nat_in = nat_ids.clone() | |
| m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio | |
| if not bool(m.any()): | |
| m[..., -1] = True | |
| if nat_mode in {"visible", "mask_plus_noise"}: | |
| clean_hn = core.emb(nat_ids) | |
| if nat_mode == "mask_plus_noise": | |
| nat_in[m] = M.BLANK | |
| hn = core.emb(nat_in) | |
| else: | |
| hn = clean_hn.clone() | |
| nat_noise = sig[:, None, None].to(clean_hn.dtype) * nat_noise_scale * torch.randn_like(clean_hn) | |
| hn = hn.clone() | |
| hn[m] = (clean_hn + nat_noise)[m] | |
| else: | |
| nat_in[m] = M.BLANK | |
| hn = core.emb(nat_in) | |
| for lpos, li in enumerate(layers): | |
| mode = _dblock_sublayer_mode_for_layer(args, state, bi, lpos) | |
| hn = _run_block(core.blocks[li], hn, None, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args, mode) | |
| Dnat = core.ln(hn) | |
| _profile_toc(state, "nat_forward", _t) | |
| _t = _profile_tic(prof) | |
| nat_hidden = Dnat[m] | |
| nat_targets = nat_ids[m] | |
| nat_hidden, nat_targets, nat_used, nat_total = _sample_token_loss_inputs( | |
| nat_hidden.unsqueeze(0), nat_targets.unsqueeze(0), int(getattr(args, "dblock_nat_loss_tokens", 0)) | |
| ) | |
| nat_raw = fused_ce(nat_hidden, nat_h.proj.weight, nat_targets) | |
| nat_raw_val = float(nat_raw.detach()) | |
| nat = nat_weight * nat_raw | |
| nat_val = float(nat.detach()) | |
| _profile_toc(state, "nat_ce", _t) | |
| _t = _profile_tic(prof) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| nat = nat + _aux.to(nat.dtype) | |
| scaler.scale(nat).backward() | |
| _profile_toc(state, "nat_backward", _t) | |
| del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat_raw, nat, nat_used, nat_total | |
| total_val = ar_val + sat_val + nat_val | |
| raw_total_val = ar_raw_val + sat_raw_val + nat_raw_val | |
| raw_count = int(bool(run_ar)) + int(bool(run_sat)) + int(bool(run_nat)) | |
| raw_avg_val = raw_total_val / max(1, raw_count) | |
| if not math.isfinite(total_val): | |
| opt.zero_grad(set_to_none=True) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True) | |
| _profile_toc(state, "step_total", _step_t) | |
| _profile_step_done(state, args) | |
| _update_stats(state, bi, total_val, args) | |
| return total_val | |
| _spike_k = float(getattr(args, "loss_spike_skip", 0.0)) | |
| if _spike_k > 0.0: | |
| _ema = state.get("spike_ema") | |
| if _ema is not None and math.isfinite(_ema) and math.isfinite(raw_avg_val) and raw_avg_val > _spike_k * _ema: | |
| opt.zero_grad(set_to_none=True) | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"[dblock] loss spike raw_avg={raw_avg_val:.2f} > {_spike_k}x EMA={_ema:.2f}; skipped optimizer step", flush=True) | |
| _profile_toc(state, "step_total", _step_t) | |
| _profile_step_done(state, args) | |
| _update_stats(state, bi, total_val, args) | |
| return total_val | |
| if math.isfinite(raw_avg_val): | |
| state["spike_ema"] = raw_avg_val if _ema is None else (0.98 * _ema + 0.02 * raw_avg_val) | |
| _t = _profile_tic(prof) | |
| scaler.unscale_(opt) | |
| nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0) | |
| scaler.step(opt) | |
| scaler.update() | |
| opt.zero_grad(set_to_none=True) | |
| _profile_toc(state, "opt_step", _t) | |
| peak_alloc = None | |
| peak_reserved = None | |
| if torch.cuda.is_available(): | |
| peak_alloc = torch.cuda.max_memory_allocated() / (1024**3) | |
| peak_reserved = torch.cuda.max_memory_reserved() / (1024**3) | |
| _profile_toc(state, "step_total", _step_t) | |
| _profile_step_done(state, args) | |
| _update_stats(state, bi, total_val, args) | |
| _maybe_log( | |
| state, | |
| args, | |
| bi, | |
| layers, | |
| ar_val, | |
| sat_val, | |
| nat_val, | |
| total_val, | |
| peak_alloc, | |
| peak_reserved, | |
| objective=objective, | |
| raw_avg=raw_avg_val, | |
| raw_total=raw_total_val, | |
| edm_weight=w, | |
| ) | |
| return raw_avg_val | |
| # ===== END dblocks_train.py ===== | |
| # ===== BEGIN nB300_agillm4.py ===== | |
| #!/usr/bin/env python3 | |
| # n.py - Joint AR+SAT+NAT Trainer with Expansion Ratio Testing | |
| # Enhanced inference: checkpoint name, tok/s, UK time | |
| import argparse, copy, json, math, pathlib, random, time, os, sys, threading, hashlib, re, subprocess | |
| from pathlib import Path | |
| from contextlib import nullcontext | |
| from typing import Dict, Any, List, Optional, Tuple | |
| from datetime import datetime, timezone | |
| _ASCII_LOG_TRANSLATION = str.maketrans({ | |
| "\u2018": "'", | |
| "\u2019": "'", | |
| "\u201a": "'", | |
| "\u201b": "'", | |
| "\u201c": '"', | |
| "\u201d": '"', | |
| "\u201e": '"', | |
| "\u201f": '"', | |
| "\u2013": "-", | |
| "\u2014": "-", | |
| "\u2212": "-", | |
| "\u2026": "...", | |
| "\u00a0": " ", | |
| }) | |
| def _ascii_log_text(text: str) -> str: | |
| return str(text).translate(_ASCII_LOG_TRANSLATION).encode("ascii", "replace").decode("ascii") | |
| class _AsciiLogStream: | |
| def __init__(self, wrapped): | |
| self._wrapped = wrapped | |
| def write(self, text): | |
| return self._wrapped.write(_ascii_log_text(text)) | |
| def flush(self): | |
| return self._wrapped.flush() | |
| def isatty(self): | |
| return self._wrapped.isatty() | |
| def fileno(self): | |
| return self._wrapped.fileno() | |
| def encoding(self): | |
| return "ascii" | |
| def __getattr__(self, name): | |
| return getattr(self._wrapped, name) | |
| if ( | |
| not sys.stdout.isatty() | |
| and os.environ.get("NB300_RAW_UNICODE_LOGS", "").lower() not in {"1", "true", "yes"} | |
| ): | |
| sys.stdout = _AsciiLogStream(sys.stdout) | |
| sys.stderr = _AsciiLogStream(sys.stderr) | |
| STATUS_SCRIPT_PATH = Path(__file__).resolve() | |
| STATUS_DEFAULT_LOG = STATUS_SCRIPT_PATH.parent / "train.log" | |
| STATUS_DEFAULT_SAVE_DIR = STATUS_SCRIPT_PATH.parent / "ckpts_expansion" | |
| _STATUS_PROGRESS_RE = re.compile( | |
| r"^\[(?P<percent>\d+(?:\.\d+)?)%\]\s+" | |
| r"(?P<seen>[\d,]+)/(?P<target>[\d,]+)\s+tok\s+\|\s+" | |
| r"(?P<tok_s>[\d.]+)\s+tok/s\s+\|\s+" | |
| r"loss=(?P<loss>-?[\d.]+)\s+B=(?P<batch>\d+)\s+L=(?P<block>\d+)" | |
| r"(?:\s+step=(?P<step>\d+))?" | |
| r"(?:\s+eta=(?P<eta>\S+))?" | |
| r"(?:\s+elapsed=(?P<elapsed>\S+))?" | |
| r"\s*$" | |
| ) | |
| _STATUS_DELTA_RE = re.compile(r"\[delta\]\s+saved\s+(?P<name>\S+?\.pt)\s+\((?P<sha>[0-9a-f]+)\.\.\.\)") | |
| _STATUS_STEP_RE = re.compile(r"step(?P<step>\d+)") | |
| def _status_iso(ts: Optional[float]) -> Optional[str]: | |
| if ts is None: | |
| return None | |
| return datetime.fromtimestamp(ts, tz=timezone.utc).astimezone().isoformat(timespec="seconds") | |
| def _status_human_duration(seconds: Optional[float]) -> Optional[str]: | |
| if seconds is None: | |
| return None | |
| total = max(0, int(seconds)) | |
| days, rem = divmod(total, 86400) | |
| hours, rem = divmod(rem, 3600) | |
| minutes, secs = divmod(rem, 60) | |
| parts = [] | |
| if days: | |
| parts.append(f"{days}d") | |
| if hours or parts: | |
| parts.append(f"{hours}h") | |
| if minutes or parts: | |
| parts.append(f"{minutes}m") | |
| parts.append(f"{secs}s") | |
| return " ".join(parts) | |
| def _status_compact_duration(seconds: Optional[float]) -> str: | |
| if seconds is None: | |
| return "unknown" | |
| try: | |
| if not math.isfinite(float(seconds)): | |
| return "unknown" | |
| except Exception: | |
| return "unknown" | |
| total = max(0, int(seconds)) | |
| years, rem = divmod(total, 365 * 86400) | |
| days, rem = divmod(rem, 86400) | |
| hours, rem = divmod(rem, 3600) | |
| minutes, secs = divmod(rem, 60) | |
| if years: | |
| return f"{years}y{days}d{hours}h" | |
| if days: | |
| return f"{days}d{hours}h{minutes}m" | |
| if hours: | |
| return f"{hours}h{minutes}m{secs}s" | |
| if minutes: | |
| return f"{minutes}m{secs}s" | |
| return f"{secs}s" | |
| def _status_format_int(value: Optional[int]) -> str: | |
| return "?" if value is None else f"{value:,}" | |
| def _status_parse_step(text: str) -> Optional[int]: | |
| match = _STATUS_STEP_RE.search(text) | |
| return int(match.group("step")) if match else None | |
| def _status_resolve_ckpt_path(raw_path: str, base_dir: Path) -> Path: | |
| ckpt_path = Path(raw_path) | |
| return ckpt_path if ckpt_path.is_absolute() else (base_dir / ckpt_path).resolve() | |
| def _status_read_cmdline(proc_dir: Path) -> Optional[List[str]]: | |
| try: | |
| data = (proc_dir / "cmdline").read_bytes().split(b"\0") | |
| return [item.decode("utf-8", errors="ignore") for item in data if item] | |
| except Exception: | |
| return None | |
| def _status_resolve_proc_arg(proc_dir: Path, raw_arg: str) -> Optional[Path]: | |
| try: | |
| arg_path = Path(raw_arg) | |
| if arg_path.is_absolute(): | |
| return arg_path.resolve() | |
| cwd = Path(os.readlink(proc_dir / "cwd")) | |
| return (cwd / arg_path).resolve() | |
| except Exception: | |
| return None | |
| def _status_proc_uptime(proc_dir: Path) -> Optional[float]: | |
| try: | |
| proc_uptime = float((Path("/proc") / "uptime").read_text().split()[0]) | |
| stat_text = (proc_dir / "stat").read_text() | |
| after = stat_text[stat_text.rfind(")") + 2:].split() | |
| start_ticks = float(after[19]) | |
| clock_ticks = os.sysconf(os.sysconf_names["SC_CLK_TCK"]) | |
| return max(0.0, proc_uptime - (start_ticks / clock_ticks)) | |
| except Exception: | |
| return None | |
| def _status_find_trainers(script_path: Path) -> List[Dict[str, Any]]: | |
| matches: List[Dict[str, Any]] = [] | |
| for proc_dir in Path("/proc").iterdir(): | |
| if not proc_dir.name.isdigit(): | |
| continue | |
| args = _status_read_cmdline(proc_dir) | |
| if not args or "train" not in args: | |
| continue | |
| resolved_script = None | |
| for arg in args: | |
| if Path(arg).name != script_path.name: | |
| continue | |
| candidate = _status_resolve_proc_arg(proc_dir, arg) | |
| if candidate == script_path: | |
| resolved_script = candidate | |
| break | |
| if resolved_script is None: | |
| continue | |
| uptime_seconds = _status_proc_uptime(proc_dir) | |
| try: | |
| cwd = str(Path(os.readlink(proc_dir / "cwd"))) | |
| except Exception: | |
| cwd = None | |
| matches.append({ | |
| "pid": int(proc_dir.name), | |
| "cmdline": " ".join(args), | |
| "args": args, | |
| "cwd": cwd, | |
| "uptime_seconds": round(uptime_seconds, 3) if uptime_seconds is not None else None, | |
| "uptime_human": _status_human_duration(uptime_seconds), | |
| }) | |
| return sorted(matches, key=lambda item: item["pid"]) | |
| def _status_parse_progress_line(line: str) -> Optional[Dict[str, Any]]: | |
| match = _STATUS_PROGRESS_RE.match(line.strip()) | |
| if not match: | |
| return None | |
| tok_per_sec = float(match.group("tok_s")) | |
| loss = float(match.group("loss")) | |
| return { | |
| "raw_line": line.strip(), | |
| "percent": float(match.group("percent")), | |
| "seen_tokens": int(match.group("seen").replace(",", "")), | |
| "target_tokens": int(match.group("target").replace(",", "")), | |
| "tok_per_sec": int(tok_per_sec) if tok_per_sec.is_integer() else tok_per_sec, | |
| "loss": loss, | |
| "batch": int(match.group("batch")), | |
| "block": int(match.group("block")), | |
| "step": int(match.group("step")) if match.group("step") else None, | |
| "eta": match.group("eta"), | |
| "elapsed": match.group("elapsed"), | |
| } | |
| def _status_parse_delta_line(line: str) -> Optional[Dict[str, Any]]: | |
| match = _STATUS_DELTA_RE.search(line) | |
| if not match: | |
| return None | |
| name = match.group("name") | |
| return { | |
| "raw_line": line.strip(), | |
| "name": name, | |
| "step": _status_parse_step(name), | |
| "sha_prefix": match.group("sha"), | |
| "source": "log", | |
| } | |
| def _status_scan_log(log_path: Path) -> tuple[Dict[str, Any], Optional[Dict[str, Any]], Optional[Dict[str, Any]], List[str]]: | |
| now = time.time() | |
| info: Dict[str, Any] = { | |
| "path": str(log_path), | |
| "exists": log_path.exists(), | |
| "mtime": None, | |
| "mtime_iso": None, | |
| "age_seconds": None, | |
| "age_human": None, | |
| "size_bytes": None, | |
| } | |
| warnings: List[str] = [] | |
| if not log_path.exists(): | |
| warnings.append(f"train log missing: {log_path}") | |
| return info, None, None, warnings | |
| try: | |
| st = log_path.stat() | |
| info["mtime"] = st.st_mtime | |
| info["mtime_iso"] = _status_iso(st.st_mtime) | |
| info["age_seconds"] = round(max(0.0, now - st.st_mtime), 3) | |
| info["age_human"] = _status_human_duration(info["age_seconds"]) | |
| info["size_bytes"] = st.st_size | |
| except Exception as exc: | |
| warnings.append(f"failed to stat train log: {exc}") | |
| last_progress = None | |
| last_delta = None | |
| try: | |
| with log_path.open("r", encoding="utf-8", errors="ignore") as handle: | |
| for raw_line in handle: | |
| line = raw_line.rstrip("\n") | |
| progress = _status_parse_progress_line(line) | |
| if progress is not None: | |
| last_progress = progress | |
| delta = _status_parse_delta_line(line) | |
| if delta is not None: | |
| last_delta = delta | |
| except Exception as exc: | |
| warnings.append(f"failed to read train log: {exc}") | |
| return info, last_progress, last_delta, warnings | |
| def _status_latest_full_checkpoint(save_dir: Path, base_dir: Path) -> tuple[Dict[str, Any], List[str]]: | |
| latest_path = save_dir / "latest.json" | |
| info: Dict[str, Any] = { | |
| "metadata_path": str(latest_path), | |
| "exists": latest_path.exists(), | |
| "raw_path": None, | |
| "checkpoint_path": None, | |
| "checkpoint_name": None, | |
| "checkpoint_exists": None, | |
| "step": None, | |
| "checkpoint_mtime": None, | |
| "checkpoint_mtime_iso": None, | |
| } | |
| warnings: List[str] = [] | |
| if not latest_path.exists(): | |
| warnings.append(f"latest.json missing: {latest_path}") | |
| return info, warnings | |
| try: | |
| payload = json.loads(latest_path.read_text(encoding="utf-8")) | |
| except Exception as exc: | |
| warnings.append(f"failed to parse latest.json: {exc}") | |
| return info, warnings | |
| raw_path = payload.get("path") | |
| info["raw_path"] = raw_path | |
| info["step"] = payload.get("step") | |
| if raw_path: | |
| ckpt_path = _status_resolve_ckpt_path(raw_path, base_dir) | |
| info["checkpoint_path"] = str(ckpt_path) | |
| info["checkpoint_name"] = ckpt_path.name | |
| info["checkpoint_exists"] = ckpt_path.exists() | |
| if ckpt_path.exists(): | |
| try: | |
| st = ckpt_path.stat() | |
| info["checkpoint_mtime"] = st.st_mtime | |
| info["checkpoint_mtime_iso"] = _status_iso(st.st_mtime) | |
| except Exception as exc: | |
| warnings.append(f"failed to stat full checkpoint: {exc}") | |
| else: | |
| warnings.append(f"latest.json points to missing checkpoint: {ckpt_path}") | |
| return info, warnings | |
| def _status_newest_delta(save_dir: Path) -> tuple[Optional[Dict[str, Any]], List[str]]: | |
| warnings: List[str] = [] | |
| if not save_dir.exists(): | |
| warnings.append(f"save dir missing: {save_dir}") | |
| return None, warnings | |
| try: | |
| candidates = [item for item in save_dir.glob("*_delta_step*.pt") if item.is_file()] | |
| except Exception as exc: | |
| warnings.append(f"failed to list delta checkpoints: {exc}") | |
| return None, warnings | |
| if not candidates: | |
| warnings.append(f"no delta checkpoints found in {save_dir}") | |
| return None, warnings | |
| newest = max(candidates, key=lambda item: item.stat().st_mtime) | |
| st = newest.stat() | |
| return { | |
| "path": str(newest), | |
| "name": newest.name, | |
| "step": _status_parse_step(newest.name), | |
| "mtime": st.st_mtime, | |
| "mtime_iso": _status_iso(st.st_mtime), | |
| "size_bytes": st.st_size, | |
| "source": "disk", | |
| }, warnings | |
| def _status_gpu_info() -> tuple[Optional[Dict[str, Any]], List[str]]: | |
| warnings: List[str] = [] | |
| try: | |
| result = subprocess.run( | |
| [ | |
| "nvidia-smi", | |
| "--query-gpu=name,utilization.gpu,memory.used,memory.total,temperature.gpu,power.draw", | |
| "--format=csv,noheader,nounits", | |
| ], | |
| capture_output=True, | |
| text=True, | |
| timeout=5, | |
| check=False, | |
| ) | |
| except FileNotFoundError: | |
| return None, warnings | |
| except Exception as exc: | |
| warnings.append(f"failed to query GPU status: {exc}") | |
| return None, warnings | |
| if result.returncode != 0: | |
| warnings.append(result.stderr.strip() or "nvidia-smi returned non-zero exit status") | |
| return None, warnings | |
| lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] | |
| if not lines: | |
| return None, warnings | |
| if len(lines) > 1: | |
| warnings.append("multiple GPUs detected; reporting the first GPU only") | |
| parts = [part.strip() for part in lines[0].split(",")] | |
| if len(parts) != 6: | |
| warnings.append(f"unexpected nvidia-smi format: {lines[0]}") | |
| return None, warnings | |
| def _parse_int(raw: str) -> Optional[int]: | |
| try: | |
| return int(float(raw)) | |
| except Exception: | |
| return None | |
| def _parse_float(raw: str) -> Optional[float]: | |
| try: | |
| return float(raw) | |
| except Exception: | |
| return None | |
| return { | |
| "name": parts[0], | |
| "utilization_gpu": _parse_int(parts[1]), | |
| "memory_used_mib": _parse_int(parts[2]), | |
| "memory_total_mib": _parse_int(parts[3]), | |
| "temperature_c": _parse_int(parts[4]), | |
| "power_draw_w": _parse_float(parts[5]), | |
| }, warnings | |
| def _status_choose_delta(from_log: Optional[Dict[str, Any]], from_disk: Optional[Dict[str, Any]], warnings: List[str]) -> Optional[Dict[str, Any]]: | |
| if from_log and from_disk: | |
| log_step = from_log.get("step") | |
| disk_step = from_disk.get("step") | |
| if log_step is not None and disk_step is not None: | |
| if log_step != disk_step: | |
| warnings.append( | |
| f"log delta step {log_step} and newest on-disk delta step {disk_step} differ; using the newer step" | |
| ) | |
| if disk_step >= log_step: | |
| merged = dict(from_disk) | |
| merged["source"] = "disk+log" if disk_step == log_step else "disk" | |
| if disk_step == log_step: | |
| merged["sha_prefix"] = from_log.get("sha_prefix") | |
| return merged | |
| return dict(from_log) | |
| return dict(from_disk) | |
| if from_disk: | |
| return dict(from_disk) | |
| if from_log: | |
| return dict(from_log) | |
| return None | |
| def _collect_status(log_path: Path, save_dir: Path) -> tuple[Dict[str, Any], int]: | |
| checked_at = time.time() | |
| requested_save_dir = save_dir.expanduser() | |
| log_path = log_path.expanduser() | |
| status: Dict[str, Any] = { | |
| "checked_at": checked_at, | |
| "checked_at_iso": _status_iso(checked_at), | |
| "running": False, | |
| "process": None, | |
| "progress": None, | |
| "delta_checkpoint": None, | |
| "delta_from_log": None, | |
| "delta_on_disk": None, | |
| "latest_full_checkpoint": None, | |
| "log": None, | |
| "gpu": None, | |
| "save_dir": { | |
| "requested_path": str(requested_save_dir), | |
| "path": str(requested_save_dir), | |
| "exists": requested_save_dir.exists(), | |
| "source": "requested", | |
| }, | |
| "warnings": [], | |
| } | |
| warnings = status["warnings"] | |
| matches = _status_find_trainers(STATUS_SCRIPT_PATH) | |
| if len(matches) > 1: | |
| status["error"] = "multiple active n.py train processes found" | |
| status["processes"] = matches | |
| return status, 1 | |
| if matches: | |
| status["running"] = True | |
| status["process"] = matches[0] | |
| save_dir = requested_save_dir | |
| if status["process"] and status["process"].get("cwd"): | |
| proc_cwd = Path(status["process"]["cwd"]) | |
| alt_save_dir = (proc_cwd / requested_save_dir.name).resolve() | |
| if alt_save_dir != requested_save_dir and alt_save_dir.exists(): | |
| requested_delta, _ = _status_newest_delta(requested_save_dir) | |
| requested_full, _ = _status_latest_full_checkpoint(requested_save_dir, STATUS_SCRIPT_PATH.parent) | |
| alt_delta, _ = _status_newest_delta(alt_save_dir) | |
| alt_full, _ = _status_latest_full_checkpoint(alt_save_dir, proc_cwd) | |
| requested_score = int(requested_delta is not None) + int(bool(requested_full.get("checkpoint_exists"))) | |
| alt_score = int(alt_delta is not None) + int(bool(alt_full.get("checkpoint_exists"))) | |
| if alt_score > requested_score: | |
| save_dir = alt_save_dir | |
| status["save_dir"] = { | |
| "requested_path": str(requested_save_dir), | |
| "path": str(save_dir), | |
| "exists": save_dir.exists(), | |
| "source": "process_cwd_fallback", | |
| } | |
| warnings.append( | |
| f"using process cwd save dir fallback: {save_dir} (requested {requested_save_dir})" | |
| ) | |
| log_info, progress, delta_from_log, log_warnings = _status_scan_log(log_path) | |
| warnings.extend(log_warnings) | |
| status["log"] = log_info | |
| status["progress"] = progress | |
| status["delta_from_log"] = delta_from_log | |
| latest_base_dir = STATUS_SCRIPT_PATH.parent | |
| if status["save_dir"].get("source") == "process_cwd_fallback" and status["process"] and status["process"].get("cwd"): | |
| latest_base_dir = Path(status["process"]["cwd"]) | |
| latest_full, latest_warnings = _status_latest_full_checkpoint(save_dir, latest_base_dir) | |
| warnings.extend(latest_warnings) | |
| status["latest_full_checkpoint"] = latest_full | |
| delta_on_disk, delta_warnings = _status_newest_delta(save_dir) | |
| warnings.extend(delta_warnings) | |
| status["delta_on_disk"] = delta_on_disk | |
| status["delta_checkpoint"] = _status_choose_delta(delta_from_log, delta_on_disk, warnings) | |
| gpu, gpu_warnings = _status_gpu_info() | |
| warnings.extend(gpu_warnings) | |
| status["gpu"] = gpu | |
| if status["running"] and log_info.get("age_seconds") is not None and log_info["age_seconds"] > 600: | |
| warnings.append(f"train log appears stale while trainer is running ({log_info['age_human']} old)") | |
| if log_info.get("exists") and progress is None: | |
| warnings.append("no parseable progress line found in train log") | |
| latest_step = latest_full.get("step") if latest_full else None | |
| delta_step = status["delta_checkpoint"].get("step") if status["delta_checkpoint"] else None | |
| if latest_step is not None and delta_step is not None and latest_step < delta_step: | |
| warnings.append(f"latest.json step {latest_step} lags newest delta step {delta_step}") | |
| if not status["running"] and progress is None: | |
| warnings.append("no active trainer process found") | |
| return status, 0 | |
| def _format_status_text(status: Dict[str, Any]) -> str: | |
| lines = [f"AGILLM status @ {status.get('checked_at_iso')}"] | |
| if status.get("error"): | |
| lines.append(f"Error: {status['error']}") | |
| for proc in status.get("processes", []): | |
| lines.append(f"- pid {proc.get('pid')}: {proc.get('cmdline')}") | |
| return "\n".join(lines) | |
| process = status.get("process") | |
| if status.get("running") and process: | |
| lines.append(f"Process: RUNNING | pid {process.get('pid')} | uptime {process.get('uptime_human') or 'unknown'}") | |
| lines.append(f"Cmd: {process.get('cmdline')}") | |
| else: | |
| lines.append("Process: NOT RUNNING") | |
| progress = status.get("progress") | |
| if progress: | |
| eta = progress.get("eta") | |
| if not eta and progress.get("tok_per_sec"): | |
| remaining = max(0, progress["target_tokens"] - progress["seen_tokens"]) | |
| eta = _status_compact_duration(remaining / float(progress["tok_per_sec"])) | |
| lines.append( | |
| "Progress: " | |
| f"{progress['percent']:.1f}% | " | |
| f"{_status_format_int(progress['seen_tokens'])}/{_status_format_int(progress['target_tokens'])} tok | " | |
| f"{progress['tok_per_sec']} tok/s | loss {progress['loss']:.3f} | " | |
| f"B={progress['batch']} L={progress['block']}" | |
| + (f" | step {progress['step']}" if progress.get("step") else "") | |
| + (f" | ETA {eta}" if eta else "") | |
| ) | |
| else: | |
| lines.append("Progress: unavailable") | |
| log_info = status.get("log") or {} | |
| if log_info.get("exists"): | |
| lines.append( | |
| f"Log: {log_info.get('path')} | updated {log_info.get('age_human') or 'unknown'} ago | " | |
| f"mtime {log_info.get('mtime_iso')}" | |
| ) | |
| else: | |
| lines.append(f"Log: missing ({log_info.get('path')})") | |
| delta = status.get("delta_checkpoint") | |
| if delta: | |
| line = f"Delta: {delta.get('name')} | step {delta.get('step')} | source {delta.get('source')}" | |
| if delta.get("path"): | |
| line += f" | {delta['path']}" | |
| lines.append(line) | |
| else: | |
| lines.append("Delta: unavailable") | |
| latest_full = status.get("latest_full_checkpoint") or {} | |
| if latest_full.get("exists"): | |
| lines.append( | |
| f"Latest full: step {latest_full.get('step')} | {latest_full.get('checkpoint_path') or latest_full.get('raw_path')}" | |
| ) | |
| else: | |
| lines.append(f"Latest full: unavailable ({latest_full.get('metadata_path')})") | |
| gpu = status.get("gpu") | |
| if gpu: | |
| lines.append( | |
| f"GPU: {gpu.get('name')} | {gpu.get('utilization_gpu')}% | " | |
| f"{gpu.get('memory_used_mib')}/{gpu.get('memory_total_mib')} MiB | " | |
| f"{gpu.get('temperature_c')}C | {gpu.get('power_draw_w')} W" | |
| ) | |
| warnings = status.get("warnings") or [] | |
| if warnings: | |
| lines.append("Warnings:") | |
| lines.extend(f"- {warning}" for warning in warnings) | |
| return "\n".join(lines) | |
| def _emit_status(log_path: Path, save_dir: Path, as_json: bool) -> int: | |
| status, exit_code = _collect_status(log_path, save_dir) | |
| if as_json: | |
| print(json.dumps(status, indent=2, sort_keys=True)) | |
| else: | |
| print(_format_status_text(status)) | |
| return exit_code | |
| def _run_status_command(argv: List[str]) -> int: | |
| parser = argparse.ArgumentParser(prog=f"{STATUS_SCRIPT_PATH.name} status", description="Read-only training status") | |
| parser.add_argument("--json", dest="json_output", action="store_true", help="Emit machine-readable JSON") | |
| parser.add_argument("--log", type=Path, default=STATUS_DEFAULT_LOG, help="Path to the training log") | |
| parser.add_argument("--save_dir", type=Path, default=STATUS_DEFAULT_SAVE_DIR, help="Checkpoint directory") | |
| args = parser.parse_args(argv) | |
| return _emit_status(args.log, args.save_dir, args.json_output) | |
| def _maybe_handle_status_fastpath() -> None: | |
| if len(sys.argv) > 1 and sys.argv[1] == "status": | |
| raise SystemExit(_run_status_command(sys.argv[2:])) | |
| _maybe_handle_status_fastpath() | |
| import torch | |
| import torch.utils.checkpoint as torch_checkpoint | |
| # SafeProgress - Claude-safe progress (discrete lines, not single growing line) | |
| class SafeProgress: | |
| def __init__(self, total, initial=0, unit="tok", print_every=100, print_every_sec=60): | |
| self.total, self.n, self.unit = total, initial, unit | |
| self.initial = initial | |
| self.last_print, self.postfix = initial, {} | |
| self.print_every = max(1, int(print_every)) | |
| self.print_every_sec = max(1, int(print_every_sec)) | |
| self.step = 0 | |
| self.last_print_step = 0 | |
| self.start_time = __import__('time').time() | |
| self.last_print_time = self.start_time | |
| def update(self, n=1): | |
| self.n += n | |
| self.step += 1 | |
| now = __import__('time').time() | |
| if ( | |
| self.step == 1 | |
| or (self.step - self.last_print_step) >= self.print_every | |
| or (now - self.last_print_time) >= self.print_every_sec | |
| ): | |
| self._print(now) | |
| self.last_print = self.n | |
| self.last_print_step = self.step | |
| self.last_print_time = now | |
| def set_postfix(self, **kwargs): self.postfix = kwargs | |
| def _print(self, now=None): | |
| now = now or __import__('time').time() | |
| elapsed = now - self.start_time | |
| rate = (self.n - self.initial) / elapsed if elapsed > 0 else 0 | |
| pct = 100 * self.n / self.total if self.total > 0 else 0 | |
| pf = ' '.join(f"{k}={v}" for k,v in self.postfix.items()) | |
| remaining = max(0, self.total - self.n) | |
| eta = _status_compact_duration(remaining / rate) if rate > 0 else "unknown" | |
| elapsed_s = _status_compact_duration(elapsed) | |
| print( | |
| f"[{pct:.4f}%] {self.n:,}/{self.total:,} {self.unit} | " | |
| f"{rate:.2f} tok/s | {pf} step={self.step} eta={eta} elapsed={elapsed_s}", | |
| flush=True, | |
| ) | |
| def close(self): self._print(); print("Done.", flush=True) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import signal | |
| import os | |
| from datasets import load_dataset, DownloadConfig | |
| from transformers import AutoTokenizer, logging as hf_log | |
| # from tqdm.auto import tqdm # DISABLED - kills Claude context | |
| # ─────────────────────────────── HOT DATASET LOADING ─────────────────────────────── | |
| HOT_CONFIG_PATH = Path("/workspace/hot_config.json") | |
| _hot_config_cache = {"mtime": 0, "data": {}} | |
| def get_hot_config() -> dict: | |
| """Load hot_config.json with caching, return empty dict if missing""" | |
| try: | |
| if HOT_CONFIG_PATH.exists(): | |
| mtime = HOT_CONFIG_PATH.stat().st_mtime | |
| if mtime > _hot_config_cache["mtime"]: | |
| with open(HOT_CONFIG_PATH) as f: | |
| _hot_config_cache["data"] = json.load(f) | |
| _hot_config_cache["mtime"] = mtime | |
| return _hot_config_cache["data"] | |
| except Exception as e: | |
| print(f"[hot_config] Error loading: {e}") | |
| return {} | |
| def get_hot_datasets(default_sources: str) -> str: | |
| """Get datasets from hot_config if present, else use default""" | |
| cfg = get_hot_config() | |
| if "datasets" in cfg and cfg["datasets"]: | |
| hot_ds = cfg["datasets"] | |
| if isinstance(hot_ds, list): | |
| hot_ds = ",".join(hot_ds) | |
| print(f"[hot_config] Using hot datasets: {hot_ds}") | |
| return hot_ds | |
| return default_sources | |
| # DISABLED: # Auto-rotating log to prevent context-window suicide | |
| # DISABLED: try: | |
| # DISABLED: from rotating_log import install_rotating_log | |
| # DISABLED: install_rotating_log() | |
| # DISABLED: except ImportError: | |
| # pass # Running without rotation | |
| # ───────────────────────── ASCII Sanitizer ───────────────────────── | |
| def _ascii_safe(s): | |
| if not isinstance(s, str): | |
| return s | |
| return (s | |
| .replace('\u2019', "'").replace('\u2018', "'") | |
| .replace('\u201C', '"').replace('\u201D', '"') | |
| .replace('\u2014', '-').replace('\u2013', '-') | |
| .replace('\u2026', '...') | |
| .replace('\u00A0', ' ')) | |
| # ───────────────────────── ANSI Colors ───────────────────────── | |
| class Colors: | |
| RESET = "\033[0m" | |
| BOLD = "\033[1m" | |
| PROMPT = "\033[36m" | |
| GEN = "\033[0m" | |
| INFO = "\033[90m" | |
| WARN = "\033[93m" | |
| # ───────────────────────── Globals ───────────────────────── | |
| hf_log.set_verbosity_error() | |
| DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| torch.backends.cuda.matmul.allow_tf32 = True | |
| try: | |
| torch.set_float32_matmul_precision("high") | |
| except Exception: | |
| pass | |
| TOKENIZER_ID = os.environ.get("TOKENIZER_ID", "deepseek-ai/DeepSeek-V4-Pro") | |
| SYNTHETIC_TOKENIZER = os.environ.get("AGILLM_SYNTHETIC_TOKENIZER", "").lower() in {"1", "true", "yes"} | |
| class _SyntheticTokenizer: | |
| pad_token = "<|pad|>" | |
| pad_token_id = 0 | |
| eos_token_id = 1 | |
| sep_token_id = 1 | |
| def __init__(self, vocab_size: int): | |
| self.vocab_size = vocab_size | |
| self.backend_tokenizer = self | |
| def add_special_tokens(self, _tokens): | |
| return 0 | |
| def get_vocab(self): | |
| return {f"tok_{i}": i for i in range(self.vocab_size)} | |
| def encode(self, text): | |
| return [2 + (ord(ch) % max(1, self.vocab_size - 2)) for ch in str(text)] | |
| def decode(self, ids, skip_special_tokens=True): | |
| return " ".join(f"tok{int(i)}" for i in ids if not skip_special_tokens or int(i) > 1) | |
| def to_str(self): | |
| return json.dumps({"type": "synthetic", "vocab_size": self.vocab_size}) | |
| if SYNTHETIC_TOKENIZER: | |
| tok = _SyntheticTokenizer(int(os.environ.get("AGILLM_SYNTHETIC_VOCAB", "8192"))) | |
| print(f"[tokenizer] synthetic tokenizer enabled vocab={tok.vocab_size}") | |
| else: | |
| _tok_src = os.environ.get("TOKENIZER_DIR", "/workspace/tokenizers/deepseek-v4-pro") | |
| if not os.path.isdir(_tok_src): | |
| _tok_src = TOKENIZER_ID | |
| try: | |
| tok = AutoTokenizer.from_pretrained(_tok_src, use_fast=True, trust_remote_code=True, local_files_only=True) | |
| except Exception as _tok_exc: | |
| print(f"[tokenizer] offline load from {_tok_src} failed ({_tok_exc}); network fallback {TOKENIZER_ID}", flush=True) | |
| tok = AutoTokenizer.from_pretrained(TOKENIZER_ID, use_fast=True, trust_remote_code=True) | |
| if tok.pad_token is None: | |
| tok.add_special_tokens({"pad_token": "<|pad|>"}) | |
| # ─── Fix tokenizer Ġ/▁ mismatch ─── | |
| # Some DeepSeek tokenizer releases use Ġ (U+0120) for space-prefixed tokens, | |
| # but some transformers versions set the Metaspace pre-tokenizer to use | |
| # ▁ (U+2581) instead, causing encode/decode to lose all spaces. | |
| def _set_backend_tokenizer(tokenizer, backend) -> None: | |
| """Swap a fast tokenizer backing tokenizers.Tokenizer across transformers versions. | |
| Modern transformers expose backend_tokenizer as a READ-ONLY property backed by | |
| _tokenizer; older versions allow direct assignment. Setting _tokenizer is what makes | |
| the checkpoint tokenizer-restore actually take effect (it was failing silently).""" | |
| try: | |
| tokenizer._tokenizer = backend | |
| return | |
| except Exception: | |
| pass | |
| tokenizer.backend_tokenizer = backend | |
| def _tokenizer_payload() -> dict: | |
| """Embed enough tokenizer state for checkpoints/deltas to be self-contained. | |
| tokenizer_json is the exact fast-tokenizer backend. tokenizer_bundle stores the | |
| small save_pretrained() files as text for environments that need config/special | |
| token metadata too. This is intentionally best-effort so a tokenizer hiccup never | |
| aborts a model save. | |
| """ | |
| out = {"tokenizer_payload_schema": 2} | |
| try: | |
| out["tokenizer_id"] = TOKENIZER_ID | |
| except Exception: | |
| pass | |
| try: | |
| out["tokenizer_json"] = tok.backend_tokenizer.to_str() | |
| except Exception as e: | |
| print(f"[tokenizer] WARNING: could not embed tokenizer_json in checkpoint: {e}") | |
| try: | |
| out["tokenizer_special"] = { | |
| "pad_token": getattr(tok, "pad_token", None), | |
| "pad_token_id": getattr(tok, "pad_token_id", None), | |
| "eos_token": getattr(tok, "eos_token", None), | |
| "eos_token_id": getattr(tok, "eos_token_id", None), | |
| "sep_token": getattr(tok, "sep_token", None), | |
| "sep_token_id": getattr(tok, "sep_token_id", None), | |
| "vocab_size": len(tok.get_vocab()) if hasattr(tok, "get_vocab") else None, | |
| } | |
| except Exception: | |
| pass | |
| try: | |
| import tempfile | |
| bundle = {} | |
| with tempfile.TemporaryDirectory(prefix="agillm_tok_") as td: | |
| tok.save_pretrained(td) | |
| for item in Path(td).iterdir(): | |
| if item.is_file() and item.stat().st_size <= 64 * 1024 * 1024: | |
| try: | |
| bundle[item.name] = item.read_text(encoding="utf-8") | |
| except UnicodeDecodeError: | |
| import base64 | |
| bundle[item.name] = {"base64": base64.b64encode(item.read_bytes()).decode("ascii")} | |
| if bundle: | |
| out["tokenizer_bundle"] = bundle | |
| except Exception as e: | |
| print(f"[tokenizer] WARNING: could not embed tokenizer bundle in checkpoint: {e}") | |
| return out | |
| def _tokenizer_sidecar_paths(path): | |
| try: | |
| p = Path(path) | |
| except Exception: | |
| return [] | |
| return [ | |
| Path(str(p) + ".tokenizer.json"), | |
| p.with_suffix(p.suffix + ".tokenizer.json"), | |
| p.parent / (p.name + ".tokenizer.json"), | |
| ] | |
| def _read_tokenizer_sidecar(path): | |
| import json as _json | |
| if not path: | |
| return {} | |
| for sidecar in _tokenizer_sidecar_paths(path): | |
| try: | |
| if sidecar.exists(): | |
| obj = _json.loads(sidecar.read_text(encoding="utf-8")) | |
| if isinstance(obj, dict): | |
| obj.setdefault("tokenizer_sidecar", str(sidecar)) | |
| return obj | |
| except Exception as exc: | |
| print(f"[tokenizer] WARNING: could not read tokenizer sidecar {sidecar}: {exc}") | |
| return {} | |
| def _write_tokenizer_sidecar(path, payload) -> None: | |
| """Write tokenizer metadata beside a full checkpoint and as latest.tokenizer.json.""" | |
| try: | |
| p = Path(path) | |
| data = dict(payload or {}) | |
| if data.get("tokenizer_json") and not data.get("tokenizer_payload_schema"): | |
| data["tokenizer_payload_schema"] = 2 | |
| data.setdefault("tokenizer_payload_schema", 2) | |
| data["checkpoint_name"] = p.name | |
| data["checkpoint_path"] = str(p) | |
| for out in (Path(str(p) + ".tokenizer.json"), p.parent / "latest.tokenizer.json"): | |
| tmp = Path(str(out) + ".tmp") | |
| tmp.write_text(json.dumps(data, ensure_ascii=False, sort_keys=True), encoding="utf-8") | |
| tmp.replace(out) | |
| except Exception as exc: | |
| print(f"[tokenizer] WARNING: could not write tokenizer sidecar for {path}: {exc}") | |
| def _apply_tokenizer_special(payload) -> None: | |
| try: | |
| spec = payload.get("tokenizer_special") if hasattr(payload, "get") else None | |
| if not isinstance(spec, dict): | |
| return | |
| if spec.get("pad_token") is not None: | |
| tok.pad_token = spec.get("pad_token") | |
| if spec.get("eos_token") is not None: | |
| tok.eos_token = spec.get("eos_token") | |
| if spec.get("sep_token") is not None: | |
| tok.sep_token = spec.get("sep_token") | |
| except Exception as exc: | |
| print(f"[tokenizer] WARNING: special-token restore skipped: {exc}") | |
| def _restore_tokenizer_from_ckpt(d, ckpt_path=None) -> None: | |
| """Make tok match what a checkpoint/delta was trained with. | |
| Embedded tokenizer_json is exact and preferred. A sidecar produced for older | |
| checkpoints is next. Runtime TOKENIZER_ID is last-resort compatibility only. | |
| Never raises: a tokenizer issue must not abort load/infer. | |
| """ | |
| try: | |
| payload = d if hasattr(d, "get") else {} | |
| if ckpt_path: | |
| sidecar = _read_tokenizer_sidecar(ckpt_path) | |
| if sidecar: | |
| merged = dict(sidecar) | |
| # Embedded checkpoint fields win, but sidecars can fill schema, | |
| # special-token metadata, or bundle files missing from old saves. | |
| merged.update({k: v for k, v in payload.items() if str(k).startswith("tokenizer_") and v is not None}) | |
| payload = merged | |
| tj = payload.get("tokenizer_json") if hasattr(payload, "get") else None | |
| if tj: | |
| from tokenizers import Tokenizer as _Tokenizer | |
| _set_backend_tokenizer(tok, _Tokenizer.from_str(tj)) | |
| _apply_tokenizer_special(payload) | |
| source = payload.get("tokenizer_sidecar") or "checkpoint" | |
| print(f"[tokenizer] Restored from {source}") | |
| return | |
| tid = payload.get("tokenizer_id") if hasattr(payload, "get") else None | |
| if tid and tid != TOKENIZER_ID: | |
| print(f"[tokenizer] WARNING: checkpoint trained with tokenizer_id={tid} but runtime TOKENIZER_ID={TOKENIZER_ID}; set TOKENIZER_ID to match") | |
| elif tid: | |
| print(f"[tokenizer] checkpoint tokenizer_id={tid} matches runtime (no embedded json)") | |
| else: | |
| print("[tokenizer] no tokenizer embedded in checkpoint; using runtime default") | |
| except Exception as e: | |
| print(f"[tokenizer] WARNING: tokenizer restore skipped: {e}") | |
| def _fix_tokenizer_space_mismatch(tokenizer): | |
| try: | |
| import json as _json | |
| from tokenizers import Tokenizer as _Tokenizer | |
| bt = tokenizer.backend_tokenizer | |
| tj = _json.loads(bt.to_str()) | |
| pre = tj.get("pre_tokenizer", {}) | |
| needs_fix = (pre.get("type") == "Metaspace" and pre.get("replacement") == "\u2581") | |
| if not needs_fix: | |
| return | |
| # Check if vocab actually uses Ġ (U+0120) for spaces | |
| vocab = tj.get("model", {}).get("vocab", {}) | |
| has_gpt2_space = any(k.startswith("\u0120") for k in list(vocab.keys())[:500]) | |
| if not has_gpt2_space: | |
| return | |
| # Patch pre_tokenizer: ▁ -> Ġ | |
| tj["pre_tokenizer"]["replacement"] = "\u0120" | |
| # Patch decoder: ▁ -> Ġ in Replace step | |
| for step in tj.get("decoder", {}).get("decoders", []): | |
| if step.get("type") == "Replace": | |
| pat = step.get("pattern", {}) | |
| if pat.get("String") == "\u2581": | |
| pat["String"] = "\u0120" | |
| # Rebuild backend tokenizer | |
| fixed = _Tokenizer.from_str(_json.dumps(tj)) | |
| _set_backend_tokenizer(tokenizer, fixed) | |
| # Verify fix | |
| test_ids = tokenizer.encode("hello world") | |
| test_dec = tokenizer.decode(test_ids, skip_special_tokens=True) | |
| if "hello world" in test_dec: | |
| print("[tokenizer] Fixed Ġ/▁ space mismatch") | |
| else: | |
| print(f"[tokenizer] WARNING: fix applied but decode test failed: {repr(test_dec)}") | |
| except Exception as e: | |
| print(f"[tokenizer] Could not fix space mismatch: {e}") | |
| if not SYNTHETIC_TOKENIZER: | |
| _fix_tokenizer_space_mismatch(tok) | |
| # ─── Tokenizer startup health check ─── | |
| # Abort early if tokenizer can't roundtrip spaces — prevents silent data corruption | |
| def _tokenizer_health_check(tokenizer): | |
| import transformers as _tf | |
| ver = _tf.__version__ | |
| print(f"[tokenizer] transformers={ver}, tokenizers={__import__('tokenizers').__version__}") | |
| # Warn on known-bad versions | |
| try: | |
| from packaging.version import Version | |
| if Version(ver) >= Version('5.0.0'): | |
| print(f'[tokenizer] WARNING: transformers {ver} may have Metaspace bug — verify carefully') | |
| except ImportError: | |
| pass | |
| # Roundtrip tests — must preserve spaces | |
| tests = [ | |
| 'Water boils at one hundred degrees', | |
| 'The quick brown fox jumps over the lazy dog', | |
| 'Hello world! This is a test sentence with spaces.', | |
| ] | |
| for text in tests: | |
| ids = tokenizer.encode(text) | |
| decoded = tokenizer.decode(ids, skip_special_tokens=True) | |
| if ' ' not in decoded: | |
| print(f'[tokenizer] FATAL: Roundtrip lost all spaces!') | |
| print(f' Input: {repr(text)}') | |
| print(f' Encoded: {ids[:20]}...') | |
| print(f' Decoded: {repr(decoded)}') | |
| print(f'[tokenizer] ABORTING — fix tokenizer before training!') | |
| sys.exit(1) | |
| # Check decoded is reasonably close to input | |
| if text.lower().split()[:3] != decoded.lower().split()[:3]: | |
| print(f'[tokenizer] WARNING: Roundtrip diverged:') | |
| print(f' Input: {repr(text[:60])}') | |
| print(f' Decoded: {repr(decoded[:60])}') | |
| print(f'[tokenizer] Health check PASSED — spaces preserved in roundtrip') | |
| if not SYNTHETIC_TOKENIZER: | |
| _tokenizer_health_check(tok) | |
| VOCAB, BLANK, EOS = ( | |
| max(tok.get_vocab().values()) + 1, | |
| int(getattr(tok, "pad_token_id", 0) or 0), | |
| tok.eos_token_id if tok.eos_token_id is not None else tok.sep_token_id | |
| ) | |
| # ───────────────────────── PRESETS ───────────────────────── | |
| PRESETS: Dict[str, Dict[str, int]] = { | |
| "femto_1x": dict(d=16, layers=1, heads=1, rank=16), | |
| "femto_12x": dict(d=16, layers=1, heads=1, rank=192), | |
| "femto_24x": dict(d=16, layers=1, heads=1, rank=384), | |
| "pico_1x": dict(d=32, layers=1, heads=2, rank=16), | |
| "pico_3x": dict(d=32, layers=1, heads=2, rank=48), | |
| "pico_6x": dict(d=32, layers=1, heads=2, rank=96), | |
| "pico_12x": dict(d=32, layers=1, heads=2, rank=192), | |
| "pico_24x": dict(d=32, layers=1, heads=2, rank=384), | |
| "pico_48x": dict(d=32, layers=1, heads=2, rank=768), | |
| "nano_1x": dict(d=64, layers=2, heads=4, rank=16), | |
| "nano_3x": dict(d=64, layers=2, heads=4, rank=48), | |
| "nano_6x": dict(d=64, layers=2, heads=4, rank=96), | |
| "nano_12x": dict(d=64, layers=2, heads=4, rank=192), | |
| "nano_24x": dict(d=64, layers=2, heads=4, rank=384), | |
| "nano_48x": dict(d=64, layers=2, heads=4, rank=768), | |
| "nano_96x": dict(d=64, layers=2, heads=4, rank=1536), | |
| "micro_3x": dict(d=128, layers=4, heads=8, rank=48), | |
| "micro_6x": dict(d=128, layers=4, heads=8, rank=96), | |
| "micro_12x": dict(d=128, layers=4, heads=8, rank=192), | |
| "micro_24x": dict(d=128, layers=4, heads=8, rank=384), | |
| "small": dict(d=512, layers=8, heads=16, rank=64), | |
| "smallx2": dict(d=512, layers=16, heads=16, rank=64), | |
| "base": dict(d=768, layers=12, heads=24, rank=96), | |
| "base18": dict(d=768, layers=18, heads=24, rank=96), | |
| "large": dict(d=1024, layers=24, heads=16, rank=128), | |
| # AGILLM-4 tiers. These are intentionally above the ~700M AGILLM-3 size. | |
| # Approx dense parameter count with the current untied embedding+AR+SAT+NAT heads: | |
| # agillm4_floor ~= 1.21B, agillm4_main ~= 1.70B, agillm4_big ~= 2.40B. | |
| "agillm4_floor": dict(d=1280, layers=28, heads=20, rank=160), | |
| "agillm4_main": dict(d=1536, layers=32, heads=24, rank=192), | |
| "agillm4_big": dict(d=1792, layers=36, heads=28, rank=224), | |
| } | |
| DEFAULT_BLOCK = 1122 | |
| DEFAULT_BATCH = 4 | |
| SAT_BLOCK = 2 | |
| LR_CORE, LR_HEAD = 5e-5, 2e-4 | |
| EMIT_LAMBDA = 0.1 | |
| DEFAULT_SAVE_SEC = 24 * 3600 | |
| DEFAULT_DELTA_STEPS = 100000 # lightweight weight-only save every N steps | |
| DEFAULT_MAX_DELTAS = 5 # keep last N deltas (older pruned after full save) | |
| CKDIR = pathlib.Path("ckpts_expansion") | |
| DEFAULT_PRETRAIN_SOURCES = "LLM360/TxT360,OpenTransformer/goddess-crawl,OpenTransformer/agillm-crawl-data,OpenTransformer/web-crawl-2026,OpenTransformer/web-crawl-clean-v2,OpenTransformer/scraped-web-data,OpenTransformer/turbo-crawl,OpenTransformer/sft-data-clean,OpenTransformer/web-crawl-v1,HuggingFaceFW/fineweb,wikimedia/wikipedia:20231101.en,allenai/c4:en,EleutherAI/proof-pile-2" | |
| DEFAULT_AFTER_SFT_SOURCES = "mlabonne/opc-sft-stage2-chat,HuggingFaceH4/ultrachat_200k@train_sft" | |
| DEFAULT_AFTER_SFT_BLOCK = 768 | |
| DEFAULT_ATTN_BACKEND = os.environ.get("AGILLM_ATTN_BACKEND", "manual") | |
| def _env_int(name: str, default: int) -> int: | |
| try: | |
| return int(os.environ.get(name, default)) | |
| except (TypeError, ValueError): | |
| return default | |
| DEFAULT_SUBLINEAR_WINDOW = _env_int("AGILLM_SUBLINEAR_WINDOW", 256) | |
| DEFAULT_SUBLINEAR_STRIDE = _env_int("AGILLM_SUBLINEAR_STRIDE", 64) | |
| DEFAULT_SUBLINEAR_MAX_ANCHORS = _env_int("AGILLM_SUBLINEAR_MAX_ANCHORS", 256) | |
| DEFAULT_SUBLINEAR_CHUNK = _env_int("AGILLM_SUBLINEAR_CHUNK", 128) | |
| DEFAULT_SUBLINEAR_SINKS = _env_int("AGILLM_SUBLINEAR_SINKS", 4) | |
| DEFAULT_SUBLINEAR_RECENT_ANCHORS = _env_int("AGILLM_SUBLINEAR_RECENT_ANCHORS", -1) # -1 = half of max anchors | |
| DEFAULT_SUBLINEAR_POOLED_LANDMARKS = bool(_env_int("AGILLM_SUBLINEAR_POOLED_LANDMARKS", 0)) | |
| DEFAULT_ANCHOR_MEMORY = bool(_env_int("AGILLM_ANCHOR_MEMORY", 0)) | |
| DEFAULT_ANCHOR_STRIDE = _env_int("AGILLM_ANCHOR_STRIDE", 256) | |
| DEFAULT_ANCHOR_MAX = _env_int("AGILLM_ANCHOR_MAX", 2048) | |
| DEFAULT_ANCHOR_POSITION = _env_int("AGILLM_ANCHOR_POSITION", -1) # -1 = stack middle | |
| DEFAULT_KV_BUFFER = bool(_env_int("AGILLM_KV_BUFFER", 0)) | |
| DEFAULT_MOE_FFN = bool(_env_int("AGILLM_MOE_FFN", 0)) | |
| DEFAULT_MOE_EXPERTS = _env_int("AGILLM_MOE_EXPERTS", 4) | |
| DEFAULT_MOE_TOP_K = _env_int("AGILLM_MOE_TOP_K", 1) | |
| DEFAULT_MOE_MLP_MULT = _env_int("AGILLM_MOE_MLP_MULT", 4) | |
| AGILLM4_TOKEN_PARAM_RATIO = 100.0 | |
| # ───────────────────────── UK Time Helper ───────────────────────── | |
| def get_uk_time() -> str: | |
| utc_now = datetime.now(timezone.utc) | |
| year = utc_now.year | |
| march_last = datetime(year, 3, 31, 1, 0, tzinfo=timezone.utc) | |
| while march_last.weekday() != 6: | |
| march_last = march_last.replace(day=march_last.day - 1) | |
| oct_last = datetime(year, 10, 31, 1, 0, tzinfo=timezone.utc) | |
| while oct_last.weekday() != 6: | |
| oct_last = oct_last.replace(day=oct_last.day - 1) | |
| if march_last <= utc_now < oct_last: | |
| uk_offset = 1 | |
| tz_name = "BST" | |
| else: | |
| uk_offset = 0 | |
| tz_name = "GMT" | |
| from datetime import timedelta | |
| uk_time = utc_now + timedelta(hours=uk_offset) | |
| return uk_time.strftime(f'%Y-%m-%d %H:%M:%S {tz_name}') | |
| # ───────────────────────── Utilities ───────────────────────── | |
| def rng_state(): | |
| if DEV.type == "cuda": | |
| try: | |
| return torch.cuda.get_rng_state(DEV) | |
| except TypeError: | |
| return torch.cuda.get_rng_state() | |
| return torch.get_rng_state() | |
| def _is_probably_ckpt(path: pathlib.Path) -> bool: | |
| try: | |
| return path.is_file() and path.suffix == ".pt" and not path.name.endswith(".pt.tmp") and path.stat().st_size > (1<<20) | |
| except Exception: | |
| return False | |
| def _resolve_ckpt(path: pathlib.Path) -> pathlib.Path | None: | |
| try: | |
| if path.is_dir(): | |
| cands = sorted([p for p in path.glob("*.pt") if _is_probably_ckpt(p)], | |
| key=lambda p: p.stat().st_mtime, reverse=True) | |
| return cands[0] if cands else None | |
| if path.suffix == ".tmp": | |
| solid = path.with_suffix("") | |
| return solid if _is_probably_ckpt(solid) else _resolve_ckpt(path.parent) | |
| return path if _is_probably_ckpt(path) else _resolve_ckpt(path.parent) | |
| except Exception: | |
| return None | |
| def _try_load(path: pathlib.Path, map_location="cpu"): | |
| try: | |
| return torch.load(path, map_location="cpu") | |
| except Exception as e: | |
| print(f"[ckpt-skip] {path} not usable: {e}") | |
| return None | |
| def _prune_checkpoints(save_dir: pathlib.Path, phase_name: str, max_ckpts: int): | |
| if max_ckpts is None or max_ckpts <= 0: | |
| return | |
| try: | |
| pattern = f"{phase_name}_step*.pt" | |
| ckpts = sorted( | |
| [p for p in save_dir.glob(pattern) if _is_probably_ckpt(p)], | |
| key=lambda p: p.stat().st_mtime | |
| ) | |
| excess = len(ckpts) - max_ckpts | |
| if excess > 0: | |
| for p in ckpts[:excess]: | |
| try: | |
| p.unlink() | |
| print(f" [prune] deleted old {p.name}") | |
| except Exception: | |
| pass | |
| except Exception as e: | |
| print(f"[ckpt-prune] error: {e}") | |
| def print_expansion_info(cfg: dict, tie_weights: bool = False, plain: bool = False): | |
| d_k = cfg["d"] // cfg["heads"] | |
| rank = cfg["rank"] | |
| ratio = rank / d_k | |
| regime = "COMPRESSION" if ratio < 1 else ("IDENTITY" if ratio == 1 else "EXPANSION") | |
| tie_str = "YES" if tie_weights else "NO" | |
| if plain: | |
| print("[attention_config]") | |
| print(f"d_model={cfg['d']} heads={cfg['heads']} d_k={d_k}") | |
| print(f"layers={cfg['layers']} tie_weights={tie_str}") | |
| print(f"rank={rank} ratio={ratio:.1f}x regime={regime}") | |
| return | |
| print(f"┌─────────────────────────────────────────┐") | |
| print(f"│ TUNEABLE ATTENTION CONFIG │") | |
| print(f"├─────────────────────────────────────────┤") | |
| print(f"│ d_model: {cfg['d']:4d} heads: {cfg['heads']:2d} d_k: {d_k:3d} │") | |
| print(f"│ layers: {cfg['layers']:4d} tie_weights: {tie_str:3s} │") | |
| print(f"│ rank: {rank:4d} ratio: {ratio:.1f}x [{regime:11s}] │") | |
| print(f"└─────────────────────────────────────────┘") | |
| # ───────────────────────── AMP helper ───────────────────────── | |
| try: | |
| from torch.amp import autocast as _ac, GradScaler | |
| except ImportError: | |
| from torch.cuda.amp import autocast as _ac, GradScaler | |
| def _auto_amp_dtype(): | |
| if DEV.type == "cuda": | |
| try: | |
| if torch.cuda.is_bf16_supported(): return torch.bfloat16 | |
| return torch.float16 | |
| except Exception: return torch.float16 | |
| return torch.float32 | |
| def amp(enabled: bool): | |
| if not enabled or DEV.type != "cuda": | |
| return nullcontext() | |
| dtype = _auto_amp_dtype() | |
| if hasattr(torch, "amp") and hasattr(torch.amp, "autocast"): | |
| try: | |
| return torch.amp.autocast("cuda", dtype=dtype) | |
| except TypeError: | |
| try: | |
| return torch.amp.autocast(device_type="cuda", dtype=dtype) | |
| except TypeError: | |
| pass | |
| return torch.cuda.amp.autocast(dtype=dtype) | |
| def _needs_grad_scaler() -> bool: | |
| return bool(DEV.type == "cuda" and _auto_amp_dtype() == torch.float16) | |
| # ───────────────────────── Chat & Data Stream ───────────────────────── | |
| def _coerce_role(r: str) -> str: | |
| r = (r or "").lower() | |
| if r in {"user", "human", "customer"}: return "user" | |
| if r in {"assistant", "gpt", "bot"}: return "assistant" | |
| if r in {"system", "context"}: return "system" | |
| return r or "user" | |
| def _chat_content(m: dict) -> str: | |
| content = m.get("content", m.get("text", m.get("value", ""))) | |
| return content if isinstance(content, str) else "" | |
| def _chat_role(m: dict) -> str: | |
| return _coerce_role(m.get("role", m.get("from", m.get("speaker", "")))) | |
| def _fallback_chat_template(messages: list[dict], add_generation_prompt: bool) -> str: | |
| parts = [] | |
| for m in messages: | |
| role = _chat_role(m) | |
| content = _chat_content(m).strip() | |
| if not content: | |
| continue | |
| if role == "system": | |
| parts.append(f"System: {content}") | |
| elif role == "assistant": | |
| parts.append(f"Assistant: {content}") | |
| else: | |
| parts.append(f"User: {content}") | |
| if add_generation_prompt and (not parts or not parts[-1].startswith("Assistant:")): | |
| parts.append("Assistant:") | |
| return "\n".join(parts) | |
| def _render_chat_text_from_ex(ex: dict, messages_key: str, add_generation_prompt: bool) -> Optional[str]: | |
| msgs = ex.get(messages_key) | |
| if msgs is None: | |
| for alt in ("conversations", "dialog", "turns"): | |
| if isinstance(ex.get(alt), list): | |
| msgs = ex[alt]; break | |
| if isinstance(msgs, list) and msgs and isinstance(msgs[0], dict): | |
| norm = [] | |
| for m in msgs: | |
| content = _chat_content(m) | |
| if not isinstance(content, str) or not content: | |
| continue | |
| norm.append({"role": _chat_role(m), "content": content}) | |
| if not norm: return None | |
| try: | |
| return tok.apply_chat_template(norm, tokenize=False, add_generation_prompt=add_generation_prompt) | |
| except Exception: | |
| return _fallback_chat_template(norm, add_generation_prompt) | |
| for a, b in (("prompt", "response"), ("instruction", "output"), ("question", "answer")): | |
| if isinstance(ex.get(a), str) and isinstance(ex.get(b), str): | |
| return f"User: {ex[a]}\nAssistant: {ex[b]}" | |
| return None | |
| def _parse_dataset_ref(ds_name: str): | |
| split = "train" | |
| ref = ds_name | |
| if "@" in ref: | |
| ref, split = ref.rsplit("@", 1) | |
| split = split or "train" | |
| if ":" in ref: | |
| base, config = ref.split(":", 1) | |
| else: | |
| base, config = ref, None | |
| return base, config, split | |
| def _open_stream_one(ds_name: str, seed: int, streaming: bool = True): | |
| dc = DownloadConfig(max_retries=5, use_etag=True, resume_download=True) | |
| base, config, split = _parse_dataset_ref(ds_name) | |
| if not streaming: | |
| print(f"[download] Downloading {ds_name} (non-streaming)...") | |
| if base == "json": | |
| data_files = {"train": config} | |
| ds = load_dataset("json", data_files=data_files, split=split, streaming=streaming, download_config=dc) | |
| else: | |
| ds = load_dataset(base, config, split=split, streaming=streaming, download_config=dc) if config else \ | |
| load_dataset(base, split=split, streaming=streaming, download_config=dc) | |
| if streaming: | |
| return iter(ds.shuffle(buffer_size=1000, seed=seed)) | |
| else: | |
| print(f"[download] Got {len(ds):,} examples. Shuffling...") | |
| ds = ds.shuffle(seed=seed) | |
| return iter(ds) | |
| def token_stream(ds_names: str, target: int, seed: int = 42, | |
| chat: bool = False, chat_messages_key: str = "messages", | |
| sft_add_generation_prompt: bool = False, dataset_field_text: str = "text", | |
| streaming: bool = True, use_hot_config: bool = True): | |
| if use_hot_config: | |
| ds_names = get_hot_datasets(ds_names) # HOT LOAD | |
| raw = [s.strip() for s in ds_names.split(",") if s.strip()] | |
| if not raw: return | |
| # Weighted interleave across sources, with an online quality router on top. | |
| # Base weights express policy; the router learns which sources yield bounded, | |
| # clean, useful examples instead of rewarding giant records for token volume. | |
| sources, weights = [], [] | |
| for s in raw: | |
| w = 1.0 | |
| head, sep, tail = s.rpartition("|") | |
| if sep: | |
| try: | |
| w = float(tail); s = head | |
| except ValueError: | |
| pass | |
| sources.append(s); weights.append(max(w, 0.0)) | |
| if sum(weights) <= 0: | |
| weights = [1.0] * len(sources) | |
| try: | |
| max_example_tokens = int(os.environ.get("AGILLM_MAX_EXAMPLE_TOKENS", "4096") or 0) | |
| except Exception: | |
| max_example_tokens = 4096 | |
| max_example_tokens = max(0, max_example_tokens) | |
| _rng = random.Random(seed) | |
| its = [None] * len(sources) | |
| emitted = 0 | |
| fail_counts = [0] * len(sources) | |
| disabled_until = [0.0] * len(sources) | |
| last_retry_log = [0.0] * len(sources) | |
| backoff_base = 2.0 | |
| max_cooldown = float(os.environ.get("AGILLM_STREAM_SOURCE_MAX_COOLDOWN_SEC", "300") or 300) | |
| fatal_cooldown = float(os.environ.get("AGILLM_STREAM_SOURCE_FATAL_COOLDOWN_SEC", "1800") or 1800) | |
| fatal_errors = {"DataFilesNotFoundError", "ArrowInvalid", "CastError", "FileNotFoundError"} | |
| router_enabled = str(os.environ.get("AGILLM_DATASET_NN_ROUTER", "1")).lower() not in {"0", "false", "off", "no"} | |
| router_state_path = Path(os.environ.get("AGILLM_DATASET_ROUTER_STATE", "/workspace/agillm_dataset_router_state.json")) | |
| router_explore = max(0.0, min(float(os.environ.get("AGILLM_DATASET_ROUTER_EXPLORE", "0.03") or 0.03), 0.50)) | |
| router_lr = max(0.0, min(float(os.environ.get("AGILLM_DATASET_ROUTER_LR", "0.03") or 0.03), 0.20)) | |
| router_min_score = max(0.01, min(float(os.environ.get("AGILLM_DATASET_ROUTER_MIN_SCORE", "0.05") or 0.05), 1.0)) | |
| router_sharpness = max(1.0, min(float(os.environ.get("AGILLM_DATASET_ROUTER_SHARPNESS", "3.0") or 3.0), 8.0)) | |
| router_log_sec = max(30.0, float(os.environ.get("AGILLM_DATASET_ROUTER_LOG_SEC", "300") or 300)) | |
| router_save_sec = max(10.0, float(os.environ.get("AGILLM_DATASET_ROUTER_SAVE_SEC", "60") or 60)) | |
| router_target_tokens = max(64.0, float(os.environ.get("AGILLM_DATASET_ROUTER_TARGET_TOKENS", str(max(512, min(max_example_tokens or 4096, 2048)))) or 2048)) | |
| router_min_quality = max(0.0, min(1.0, float(os.environ.get("AGILLM_DATASET_ROUTER_MIN_QUALITY", "0.45") or 0.45))) | |
| router_last_log = 0.0 | |
| router_last_save = 0.0 | |
| def _env_bool(name, default=False): | |
| return str(os.environ.get(name, "1" if default else "0")).strip().lower() not in {"", "0", "false", "off", "no"} | |
| def _env_float(name, default, lo=None, hi=None): | |
| try: | |
| val = float(os.environ.get(name, str(default)) or default) | |
| except Exception: | |
| val = float(default) | |
| if lo is not None: | |
| val = max(float(lo), val) | |
| if hi is not None: | |
| val = min(float(hi), val) | |
| return val | |
| agent_enabled = _env_bool("AGILLM_DATASET_AGENT_ROUTER", False) | |
| agent_timeout = _env_float("AGILLM_DATASET_AGENT_TIMEOUT_SEC", 8.0, 1.0, 60.0) | |
| agent_min_interval = _env_float("AGILLM_DATASET_AGENT_MIN_INTERVAL_SEC", 600.0, 30.0, 86400.0) | |
| agent_source_interval = _env_float("AGILLM_DATASET_AGENT_SOURCE_INTERVAL_SEC", 900.0, 30.0, 86400.0) | |
| agent_fail_threshold = int(_env_float("AGILLM_DATASET_AGENT_FAILS", 2.0, 1.0, 50.0)) | |
| agent_min_pulls = int(_env_float("AGILLM_DATASET_AGENT_MIN_PULLS", 4.0, 1.0, 1000.0)) | |
| agent_err_threshold = _env_float("AGILLM_DATASET_AGENT_ERR_EMA", 0.18, 0.01, 1.0) | |
| agent_empty_threshold = _env_float("AGILLM_DATASET_AGENT_EMPTY_EMA", 0.20, 0.01, 1.0) | |
| agent_latency_threshold = _env_float("AGILLM_DATASET_AGENT_LATENCY_SEC", 20.0, 1.0, 600.0) | |
| agent_min_conf = _env_float("AGILLM_DATASET_AGENT_MIN_CONF", 0.25, 0.0, 1.0) | |
| agent_default_penalty = _env_float("AGILLM_DATASET_AGENT_PENALTY", 0.35, 0.01, 1.0) | |
| agent_default_cooldown = _env_float("AGILLM_DATASET_AGENT_COOLDOWN_SEC", 900.0, 30.0, 86400.0) | |
| agent_disable_sec = _env_float("AGILLM_DATASET_AGENT_DISABLE_SEC", 21600.0, 60.0, 604800.0) | |
| agent_last_call = 0.0 | |
| def _sigmoid(x): | |
| if x < -40.0: return 0.0 | |
| if x > 40.0: return 1.0 | |
| return 1.0 / (1.0 + math.exp(-x)) | |
| def _load_router_state(): | |
| default_weights = [-0.15, 0.85, 1.40, -2.00, -0.25, 0.90, -2.50, 2.40, -3.00, -2.80, -1.60, -0.80] | |
| default = { | |
| "schema": "agillm.dataset_router.v2", | |
| "updated_utc": "", | |
| "weights": list(default_weights), | |
| "sources": {}, | |
| "agent": {}, | |
| } | |
| try: | |
| if router_state_path.exists(): | |
| loaded = json.loads(router_state_path.read_text()) | |
| if isinstance(loaded, dict): | |
| default.update({k: loaded.get(k, default[k]) for k in default}) | |
| if not isinstance(default.get("sources"), dict): | |
| default["sources"] = {} | |
| if default.get("schema") != "agillm.dataset_router.v2": | |
| default["schema"] = "agillm.dataset_router.v2" | |
| default["weights"] = list(default_weights) | |
| if not isinstance(default.get("weights"), list) or len(default["weights"]) != len(default_weights): | |
| default["weights"] = list(default_weights) | |
| except Exception as exc: | |
| print(f"[dataset-router] warning: could not load {router_state_path}: {exc}", flush=True) | |
| return default | |
| router = _load_router_state() | |
| router.setdefault("agent", {}) | |
| try: | |
| agent_last_call = float(router["agent"].get("last_call", 0.0) or 0.0) | |
| except Exception: | |
| agent_last_call = 0.0 | |
| def _source_state(src): | |
| st = router.setdefault("sources", {}).setdefault(src, {}) | |
| st.setdefault("ok_ema", 0.55) | |
| st.setdefault("err_ema", 0.05) | |
| st.setdefault("lat_ema", 1.0) | |
| st.setdefault("tok_ema", 256.0) | |
| st.setdefault("token_fit_ema", 0.50) | |
| st.setdefault("quality_ema", 0.65) | |
| st.setdefault("replacement_ema", 0.0) | |
| st.setdefault("control_ema", 0.0) | |
| st.setdefault("repeat_ema", 0.0) | |
| st.setdefault("short_ema", 0.05) | |
| st.setdefault("empty_ema", 0.05) | |
| st.setdefault("pulls", 0) | |
| st.setdefault("tokens", 0) | |
| st.setdefault("errors", 0) | |
| st.setdefault("empty", 0) | |
| st.setdefault("last_ok", 0.0) | |
| st.setdefault("last_error", "") | |
| st.setdefault("last_score", 0.5) | |
| st.setdefault("last_quality", 0.65) | |
| st.setdefault("agent_score_mult", 1.0) | |
| st.setdefault("agent_penalty_until", 0.0) | |
| st.setdefault("agent_last_check", 0.0) | |
| st.setdefault("agent_last_action", "") | |
| st.setdefault("agent_last_reason", "") | |
| st.setdefault("agent_last_error", "") | |
| return st | |
| for src in sources: | |
| _source_state(src) | |
| def _router_features(i, now): | |
| total_w = max(sum(weights), 1e-9) | |
| base = max(weights[i], 0.0) / total_w | |
| st = _source_state(sources[i]) | |
| return [ | |
| 1.0, | |
| min(1.0, base * len(weights)), | |
| float(st.get("ok_ema", 0.55)), | |
| float(st.get("err_ema", 0.05)), | |
| min(1.0, float(st.get("lat_ema", 1.0)) / 15.0), | |
| float(st.get("token_fit_ema", 0.50)), | |
| float(st.get("empty_ema", 0.05)), | |
| float(st.get("quality_ema", 0.65)), | |
| float(st.get("replacement_ema", 0.0)), | |
| float(st.get("control_ema", 0.0)), | |
| float(st.get("repeat_ema", 0.0)), | |
| float(st.get("short_ema", 0.05)), | |
| ] | |
| def _router_score(i, now): | |
| if not router_enabled: | |
| return 1.0 | |
| ws = router.get("weights") or [] | |
| feats = _router_features(i, now) | |
| z = sum(float(w) * float(f) for w, f in zip(ws, feats)) | |
| score = max(router_min_score, min(1.0, _sigmoid(z))) | |
| st = _source_state(sources[i]) | |
| try: | |
| until = float(st.get("agent_penalty_until", 0.0) or 0.0) | |
| mult = max(0.01, min(2.0, float(st.get("agent_score_mult", 1.0) or 1.0))) | |
| except Exception: | |
| until, mult = 0.0, 1.0 | |
| if until > now: | |
| score = max(router_min_score, min(1.0, score * mult)) | |
| elif until or mult != 1.0: | |
| st["agent_score_mult"] = 1.0 | |
| st["agent_penalty_until"] = 0.0 | |
| st["last_score"] = score | |
| return score | |
| def _save_router_state(force=False): | |
| nonlocal router_last_save | |
| now = time.time() | |
| if not force and now - router_last_save < router_save_sec: | |
| return | |
| router_last_save = now | |
| try: | |
| router["updated_utc"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(now)) | |
| tmp = router_state_path.with_suffix(router_state_path.suffix + ".tmp") | |
| tmp.parent.mkdir(parents=True, exist_ok=True) | |
| tmp.write_text(json.dumps(router, indent=2, sort_keys=True) + "\n") | |
| tmp.replace(router_state_path) | |
| except Exception as exc: | |
| print(f"[dataset-router] warning: could not save {router_state_path}: {exc}", flush=True) | |
| def _agent_read_secret(env_names, paths): | |
| for name in env_names: | |
| val = os.environ.get(name, "") | |
| if val.strip(): | |
| return val.strip() | |
| for raw_path in paths: | |
| try: | |
| p = Path(raw_path).expanduser() | |
| if p.exists(): | |
| val = p.read_text(errors="ignore").strip() | |
| if val: | |
| return val | |
| except Exception: | |
| pass | |
| return "" | |
| def _agent_provider_key_model(): | |
| pref = str(os.environ.get("AGILLM_DATASET_AGENT_PROVIDER", "auto") or "auto").strip().lower() | |
| deepseek_key = _agent_read_secret( | |
| ("DEEPSEEK_API_KEY", "AGILLM_DEEPSEEK_API_KEY"), | |
| ( | |
| "/root/.config/agillm/deepseek_api_key", | |
| "/workspace/private/deepseek_api_key", | |
| "/workspace/agillm_private/deepseek_api_key", | |
| ), | |
| ) | |
| openrouter_key = _agent_read_secret( | |
| ("OPENROUTER_API_KEY", "AGILLM_OPENROUTER_API_KEY"), | |
| ( | |
| "/root/.config/agillm/openrouter_api_key", | |
| "/workspace/private/openrouter_api_key", | |
| "/workspace/agillm_private/openrouter_api_key", | |
| ), | |
| ) | |
| deepseek_model = os.environ.get("AGILLM_DATASET_AGENT_DEEPSEEK_MODEL", "deepseek-chat") | |
| openrouter_model = os.environ.get("AGILLM_DATASET_AGENT_OPENROUTER_MODEL", "deepseek/deepseek-chat-v3-0324") | |
| if pref == "deepseek": | |
| return "deepseek", deepseek_key, deepseek_model, "configured" if deepseek_key else "missing-key" | |
| if pref == "openrouter": | |
| return "openrouter", openrouter_key, openrouter_model, "configured" if openrouter_key else "missing-key" | |
| if deepseek_key: | |
| return "deepseek", deepseek_key, deepseek_model, "configured" | |
| if openrouter_key: | |
| return "openrouter", openrouter_key, openrouter_model, "configured" | |
| return "auto", "", "", "missing-key" | |
| def _agent_extract_json(text): | |
| text = str(text or "").strip() | |
| if not text: | |
| return {} | |
| try: | |
| obj = json.loads(text) | |
| return obj if isinstance(obj, dict) else {} | |
| except Exception: | |
| pass | |
| start, end = text.find("{"), text.rfind("}") | |
| if start >= 0 and end > start: | |
| try: | |
| obj = json.loads(text[start:end + 1]) | |
| return obj if isinstance(obj, dict) else {} | |
| except Exception: | |
| return {} | |
| return {} | |
| def _agent_call(provider, key, model, payload): | |
| import urllib.error | |
| import urllib.request | |
| if provider == "deepseek": | |
| url = "https://api.deepseek.com/chat/completions" | |
| headers = {"Authorization": "Bearer " + key, "Content-Type": "application/json"} | |
| elif provider == "openrouter": | |
| url = "https://openrouter.ai/api/v1/chat/completions" | |
| headers = { | |
| "Authorization": "Bearer " + key, | |
| "Content-Type": "application/json", | |
| "HTTP-Referer": "https://join.opentransformers.online", | |
| "X-Title": "AGILLM dataset router", | |
| } | |
| else: | |
| return False, "unknown_provider" | |
| system = ( | |
| "You are a dataset routing policy agent for an active neural-network training run. " | |
| "Return compact JSON only. You may advise rerouting, cooldown, penalizing, disabling, keeping, or recovering a dataset source. " | |
| "Never create, rewrite, summarize, or transform training samples. " | |
| "Allowed actions: keep, penalize, cooldown, disable, recover. " | |
| "Use score_multiplier between 0.01 and 2.0 and cooldown_sec as seconds." | |
| ) | |
| body = { | |
| "model": model, | |
| "messages": [ | |
| {"role": "system", "content": system}, | |
| {"role": "user", "content": json.dumps(payload, sort_keys=True)}, | |
| ], | |
| "temperature": 0, | |
| "max_tokens": 180, | |
| } | |
| data = json.dumps(body).encode("utf-8") | |
| req = urllib.request.Request(url, data=data, headers=headers, method="POST") | |
| try: | |
| with urllib.request.urlopen(req, timeout=agent_timeout) as resp: | |
| raw = resp.read(32768).decode("utf-8", errors="replace") | |
| parsed = json.loads(raw) | |
| content = (((parsed.get("choices") or [{}])[0].get("message") or {}).get("content") or "") | |
| if not content and isinstance(parsed.get("output"), str): | |
| content = parsed["output"] | |
| return True, content | |
| except urllib.error.HTTPError as exc: | |
| return False, f"HTTP{getattr(exc, 'code', 'error')}" | |
| except Exception as exc: | |
| return False, type(exc).__name__ | |
| def _agent_maybe_advise(i, event): | |
| nonlocal agent_last_call | |
| if not agent_enabled or i is None: | |
| return | |
| now = time.time() | |
| st = _source_state(sources[i]) | |
| pulls = int(st.get("pulls", 0)) | |
| errors = int(st.get("errors", 0)) | |
| if pulls < agent_min_pulls and errors < agent_fail_threshold: | |
| return | |
| bad_enough = ( | |
| fail_counts[i] >= agent_fail_threshold | |
| or errors >= agent_fail_threshold | |
| or float(st.get("err_ema", 0.0)) >= agent_err_threshold | |
| or float(st.get("empty_ema", 0.0)) >= agent_empty_threshold | |
| or float(st.get("lat_ema", 0.0)) >= agent_latency_threshold | |
| ) | |
| if not bad_enough: | |
| return | |
| if now - agent_last_call < agent_min_interval: | |
| return | |
| if now - float(st.get("agent_last_check", 0.0) or 0.0) < agent_source_interval: | |
| return | |
| provider, key, model, status = _agent_provider_key_model() | |
| if not key: | |
| router.setdefault("agent", {})["last_status"] = status | |
| st["agent_last_check"] = now | |
| st["agent_last_error"] = status | |
| _save_router_state(force=True) | |
| return | |
| st["agent_last_check"] = now | |
| router.setdefault("agent", {})["last_call"] = now | |
| router["agent"]["last_provider"] = provider | |
| router["agent"]["last_model"] = model | |
| agent_last_call = now | |
| payload = { | |
| "source_index": i, | |
| "source": sources[i], | |
| "event": str(event or "failure")[:120], | |
| "policy": "reroute/cooldown only; never generate or modify data", | |
| "stats": { | |
| "pulls": pulls, | |
| "errors": errors, | |
| "empty": int(st.get("empty", 0)), | |
| "fail_count": int(fail_counts[i]), | |
| "ok_ema": float(st.get("ok_ema", 0.0)), | |
| "err_ema": float(st.get("err_ema", 0.0)), | |
| "empty_ema": float(st.get("empty_ema", 0.0)), | |
| "lat_ema": float(st.get("lat_ema", 0.0)), | |
| "tok_ema": float(st.get("tok_ema", 0.0)), | |
| "token_fit_ema": float(st.get("token_fit_ema", 0.0)), | |
| "quality_ema": float(st.get("quality_ema", 0.0)), | |
| "replacement_ema": float(st.get("replacement_ema", 0.0)), | |
| "control_ema": float(st.get("control_ema", 0.0)), | |
| "repeat_ema": float(st.get("repeat_ema", 0.0)), | |
| "router_score": float(st.get("last_score", 0.5)), | |
| "disabled_for_sec": max(0.0, float(disabled_until[i]) - now), | |
| "agent_score_mult": float(st.get("agent_score_mult", 1.0) or 1.0), | |
| }, | |
| "return_schema": { | |
| "action": "keep|penalize|cooldown|disable|recover", | |
| "score_multiplier": 0.35, | |
| "cooldown_sec": 900, | |
| "confidence": 0.5, | |
| "reason": "short reason", | |
| }, | |
| } | |
| ok, content = _agent_call(provider, key, model, payload) | |
| if not ok: | |
| st["agent_last_error"] = str(content)[:120] | |
| print(f"[dataset-agent] provider={provider} model={model} src={i}:{sources[i][:42]} error={content}", flush=True) | |
| _save_router_state(force=True) | |
| return | |
| advice = _agent_extract_json(content) | |
| action = str(advice.get("action", "keep") or "keep").strip().lower() | |
| if action not in {"keep", "penalize", "cooldown", "disable", "recover"}: | |
| action = "keep" | |
| try: | |
| confidence = max(0.0, min(1.0, float(advice.get("confidence", 0.0) or 0.0))) | |
| except Exception: | |
| confidence = 0.0 | |
| if confidence < agent_min_conf: | |
| action = "keep" | |
| try: | |
| mult = max(0.01, min(2.0, float(advice.get("score_multiplier", agent_default_penalty) or agent_default_penalty))) | |
| except Exception: | |
| mult = agent_default_penalty | |
| try: | |
| cooldown_sec = max(0.0, float(advice.get("cooldown_sec", agent_default_cooldown) or agent_default_cooldown)) | |
| except Exception: | |
| cooldown_sec = agent_default_cooldown | |
| reason = str(advice.get("reason", "") or "")[:180] | |
| if action == "recover": | |
| st["agent_score_mult"] = 1.0 | |
| st["agent_penalty_until"] = 0.0 | |
| disabled_until[i] = 0.0 | |
| elif action == "penalize": | |
| st["agent_score_mult"] = min(float(st.get("agent_score_mult", 1.0) or 1.0), mult) | |
| st["agent_penalty_until"] = max(float(st.get("agent_penalty_until", 0.0) or 0.0), now + max(cooldown_sec, agent_default_cooldown)) | |
| elif action == "cooldown": | |
| st["agent_score_mult"] = min(float(st.get("agent_score_mult", 1.0) or 1.0), mult) | |
| until = now + max(cooldown_sec, agent_default_cooldown) | |
| st["agent_penalty_until"] = max(float(st.get("agent_penalty_until", 0.0) or 0.0), until) | |
| disabled_until[i] = max(disabled_until[i], until) | |
| elif action == "disable": | |
| st["agent_score_mult"] = min(float(st.get("agent_score_mult", 1.0) or 1.0), min(mult, agent_default_penalty)) | |
| until = now + max(cooldown_sec, agent_disable_sec) | |
| st["agent_penalty_until"] = max(float(st.get("agent_penalty_until", 0.0) or 0.0), until) | |
| disabled_until[i] = max(disabled_until[i], until) | |
| st["agent_last_action"] = action | |
| st["agent_last_reason"] = reason | |
| st["agent_last_error"] = "" | |
| router.setdefault("agent", {})["last_status"] = "ok" | |
| _save_router_state(force=True) | |
| print( | |
| f"[dataset-agent] provider={provider} model={model} src={i}:{sources[i][:42]} " | |
| f"event={str(event)[:40]} action={action} mult={mult:.2f} cooldown={cooldown_sec:.0f}s conf={confidence:.2f} reason={reason}", | |
| flush=True, | |
| ) | |
| def _score_text_sample(text, token_count): | |
| preview = str(text or "")[:65536] | |
| n = max(1, len(preview)) | |
| repl = preview.count("\ufffd") / n | |
| control = sum(1 for ch in preview if ord(ch) < 32 and ch not in "\n\r\t") / n | |
| long_runs = 0 | |
| run = 1 | |
| prev = "" | |
| for ch in preview: | |
| if ch == prev: | |
| run += 1 | |
| else: | |
| if run >= 12: | |
| long_runs += run | |
| prev = ch | |
| run = 1 | |
| if run >= 12: | |
| long_runs += run | |
| repeat = long_runs / n | |
| whitespace = sum(1 for ch in preview if ch.isspace()) / n | |
| alpha = sum(1 for ch in preview if ch.isalpha()) / n | |
| digit = sum(1 for ch in preview if ch.isdigit()) / n | |
| tok = max(0.0, float(token_count or 0.0)) | |
| token_fit = max(0.0, min(1.0, 1.0 - abs(tok - router_target_tokens) / max(router_target_tokens, 1.0))) | |
| short = 1.0 if tok < min(128.0, router_target_tokens * 0.25) else 0.0 | |
| quality = 1.0 | |
| quality -= min(0.55, repl * 18.0) | |
| quality -= min(0.40, control * 28.0) | |
| quality -= min(0.35, repeat * 7.0) | |
| if whitespace < 0.04 or whitespace > 0.55: | |
| quality -= 0.12 | |
| if alpha < 0.18 and digit > 0.35: | |
| quality -= 0.16 | |
| if tok < 32: | |
| quality -= 0.35 | |
| elif tok < 128: | |
| quality -= 0.12 | |
| quality = max(0.0, min(1.0, quality)) | |
| return quality, token_fit, repl, control, repeat, short | |
| def _router_update(i, label, feat, token_count=0, latency=0.0, err="", empty=False, quality=None, token_fit=None, replacement_rate=0.0, control_rate=0.0, repeat_rate=0.0, short=0.0): | |
| if i is None: | |
| return | |
| st = _source_state(sources[i]) | |
| try: | |
| label = max(0.0, min(1.0, float(label))) | |
| except Exception: | |
| label = 0.0 | |
| alpha = 0.04 | |
| q = float(st.get("quality_ema", 0.65) if quality is None else max(0.0, min(1.0, float(quality)))) | |
| fit = float(st.get("token_fit_ema", 0.50) if token_fit is None else max(0.0, min(1.0, float(token_fit)))) | |
| replacement_rate = max(0.0, min(1.0, float(replacement_rate or 0.0))) | |
| control_rate = max(0.0, min(1.0, float(control_rate or 0.0))) | |
| repeat_rate = max(0.0, min(1.0, float(repeat_rate or 0.0))) | |
| short = max(0.0, min(1.0, float(short or 0.0))) | |
| st["pulls"] = int(st.get("pulls", 0)) + 1 | |
| st["ok_ema"] = (1.0 - alpha) * float(st.get("ok_ema", 0.55)) + alpha * label | |
| st["err_ema"] = (1.0 - alpha) * float(st.get("err_ema", 0.05)) + alpha * (1.0 - label) | |
| st["lat_ema"] = (1.0 - alpha) * float(st.get("lat_ema", 1.0)) + alpha * max(float(latency or 0.0), 0.0) | |
| st["tok_ema"] = (1.0 - alpha) * float(st.get("tok_ema", 256.0)) + alpha * max(float(token_count or 0.0), 0.0) | |
| st["token_fit_ema"] = (1.0 - alpha) * float(st.get("token_fit_ema", 0.50)) + alpha * fit | |
| st["quality_ema"] = (1.0 - alpha) * float(st.get("quality_ema", 0.65)) + alpha * q | |
| st["replacement_ema"] = (1.0 - alpha) * float(st.get("replacement_ema", 0.0)) + alpha * replacement_rate | |
| st["control_ema"] = (1.0 - alpha) * float(st.get("control_ema", 0.0)) + alpha * control_rate | |
| st["repeat_ema"] = (1.0 - alpha) * float(st.get("repeat_ema", 0.0)) + alpha * repeat_rate | |
| st["short_ema"] = (1.0 - alpha) * float(st.get("short_ema", 0.05)) + alpha * short | |
| st["empty_ema"] = (1.0 - alpha) * float(st.get("empty_ema", 0.05)) + alpha * (1.0 if empty else 0.0) | |
| st["last_quality"] = q | |
| if label >= 0.5: | |
| st["tokens"] = int(st.get("tokens", 0)) + int(token_count or 0) | |
| st["last_ok"] = time.time() | |
| st["last_error"] = "" | |
| else: | |
| st["errors"] = int(st.get("errors", 0)) + 1 | |
| st["last_error"] = str(err or "bad_sample")[:120] | |
| if empty: | |
| st["empty"] = int(st.get("empty", 0)) + 1 | |
| if router_enabled and feat and router_lr > 0: | |
| pred = _sigmoid(sum(float(w) * float(f) for w, f in zip(router["weights"], feat))) | |
| grad = label - pred | |
| router["weights"] = [max(-8.0, min(8.0, float(w) + router_lr * grad * float(f))) for w, f in zip(router["weights"], feat)] | |
| _save_router_state(force=(label < 0.5 or int(st.get("pulls", 0)) <= 3 or (int(st.get("pulls", 0)) % 25 == 0))) | |
| def _choose_source(available, now): | |
| if not router_enabled or _rng.random() < router_explore: | |
| return _rng.choices(available, weights=[weights[i] for i in available])[0] | |
| eff = [] | |
| for i in available: | |
| score = _router_score(i, now) | |
| eff.append(max(1e-9, weights[i] * (score ** router_sharpness))) | |
| if sum(eff) <= 0: | |
| eff = [weights[i] for i in available] | |
| return _rng.choices(available, weights=eff)[0] | |
| agent_provider, agent_key, agent_model, agent_status = _agent_provider_key_model() | |
| if not agent_enabled: | |
| agent_desc = "off" | |
| elif agent_key: | |
| agent_desc = f"{agent_provider}:{agent_model}" | |
| else: | |
| agent_desc = f"{agent_provider}:missing-key" | |
| print( | |
| f"[dataset-router] nn={'on' if router_enabled else 'off'} explore={router_explore:.3f} " | |
| f"agent={agent_desc} state={router_state_path} sources={len(sources)}", | |
| flush=True, | |
| ) | |
| while emitted < target: | |
| now = time.time() | |
| available = [i for i, w in enumerate(weights) if w > 0.0 and disabled_until[i] <= now] | |
| if not available: | |
| next_ready = min(disabled_until) if disabled_until else now + 1.0 | |
| sleep_s = max(1.0, min(30.0, next_ready - now)) | |
| print(f"[stream-retry] all sources cooling down, sleeping {sleep_s:.1f}s", flush=True) | |
| time.sleep(sleep_s) | |
| continue | |
| if router_enabled and now - router_last_log >= router_log_sec: | |
| rows = [] | |
| for i in range(len(sources)): | |
| st = _source_state(sources[i]) | |
| rows.append((float(st.get("last_score", _router_score(i, now))), i, st)) | |
| rows.sort(reverse=True) | |
| msg = "; ".join( | |
| f"{i}:{sources[i][:36]} score={score:.2f} q={st.get('quality_ema', 0):.2f} fit={st.get('token_fit_ema', 0):.2f} ok={st.get('ok_ema', 0):.2f} err={st.get('err_ema', 0):.2f} tok={st.get('tok_ema', 0):.0f}" | |
| for score, i, st in rows[:5] | |
| ) | |
| print(f"[dataset-router] {msg}", flush=True) | |
| router_last_log = now | |
| src_idx = _choose_source(available, now) | |
| feat = _router_features(src_idx, now) | |
| t0 = time.perf_counter() | |
| try: | |
| if its[src_idx] is None: | |
| its[src_idx] = _open_stream_one(sources[src_idx], seed + src_idx, streaming=streaming) | |
| ex = next(its[src_idx]) | |
| text = None | |
| if isinstance(ex, dict): | |
| if chat: | |
| text = _render_chat_text_from_ex(ex, chat_messages_key, sft_add_generation_prompt) | |
| if text is None: | |
| if dataset_field_text and isinstance(ex.get(dataset_field_text), str): | |
| text = ex[dataset_field_text] | |
| elif isinstance(ex.get("text"), str): | |
| text = ex["text"] | |
| if not isinstance(text, str) or not text.strip(): | |
| _router_update(src_idx, 0, feat, latency=time.perf_counter() - t0, err="empty_or_missing_text", empty=True) | |
| _agent_maybe_advise(src_idx, "empty_or_missing_text") | |
| continue | |
| if fail_counts[src_idx]: | |
| print(f"[stream-recover] {sources[src_idx]} recovered after {fail_counts[src_idx]} failures", flush=True) | |
| fail_counts[src_idx] = 0 | |
| disabled_until[src_idx] = 0.0 | |
| max_example_chars = int(os.environ.get("AGILLM_MAX_EXAMPLE_CHARS", str(max(8192, (max_example_tokens or 4096) * 8))) or 0) | |
| if max_example_chars and len(text) > max_example_chars: | |
| span_chars = max(1, len(text) - max_example_chars + 1) | |
| start_chars = _rng.randrange(span_chars) | |
| text = text[start_chars:start_chars + max_example_chars] | |
| enc = tok.encode(text) | |
| if EOS is not None and (len(enc) == 0 or enc[-1] != EOS): | |
| enc = enc + [EOS] | |
| if max_example_tokens and len(enc) > max_example_tokens: | |
| span = max(1, len(enc) - max_example_tokens + 1) | |
| start = _rng.randrange(span) | |
| enc = enc[start:start + max_example_tokens] | |
| if not enc: | |
| _router_update(src_idx, 0, feat, latency=time.perf_counter() - t0, err="empty_tokens", empty=True) | |
| _agent_maybe_advise(src_idx, "empty_tokens") | |
| continue | |
| quality, token_fit, replacement_rate, control_rate, repeat_rate, short = _score_text_sample(text, len(enc)) | |
| label = quality if quality >= router_min_quality else max(0.0, quality * 0.5) | |
| _router_update(src_idx, label, feat, token_count=len(enc), latency=time.perf_counter() - t0, quality=quality, token_fit=token_fit, replacement_rate=replacement_rate, control_rate=control_rate, repeat_rate=repeat_rate, short=short) | |
| for t in enc: | |
| yield t | |
| emitted += 1 | |
| if emitted >= target: | |
| _save_router_state(force=True) | |
| return | |
| except StopIteration: | |
| its[src_idx] = None # exhausted: reopen on next pick (stream cycles) | |
| except Exception as e: | |
| its[src_idx] = None | |
| fail_counts[src_idx] += 1 | |
| err = type(e).__name__ | |
| _router_update(src_idx, 0, feat, latency=time.perf_counter() - t0, err=err) | |
| cooldown = min(max_cooldown, backoff_base ** min(fail_counts[src_idx], 8)) | |
| if err in fatal_errors: | |
| cooldown = max(cooldown, fatal_cooldown) | |
| disabled_until[src_idx] = time.time() + cooldown | |
| _agent_maybe_advise(src_idx, err) | |
| if time.time() - last_retry_log[src_idx] > 15.0 or fail_counts[src_idx] <= 2: | |
| print( | |
| f"[stream-retry] {sources[src_idx]} error: {err}, " | |
| f"cooling {cooldown:.1f}s failures={fail_counts[src_idx]}", | |
| flush=True, | |
| ) | |
| last_retry_log[src_idx] = time.time() | |
| # ───────────────────────── ALiBi ───────────────────────── | |
| def _alibi_slopes(n_heads: int): | |
| def pow2slopes(n): | |
| start = 2 ** (-2 ** -(math.log2(n) - 3)) | |
| ratio = start | |
| return [start * (ratio ** i) for i in range(n)] | |
| if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads) | |
| else: | |
| closest = 2 ** math.floor(math.log2(n_heads)) | |
| vals = pow2slopes(closest) | |
| extra = pow2slopes(2 * closest) | |
| vals += extra[0::2][: n_heads - closest] | |
| return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1) | |
| def alibi_bias(n_heads: int, n_tokens: int): | |
| i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1) | |
| j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens) | |
| dist = (j - i).clamp_min(0) | |
| return -_alibi_slopes(n_heads) * dist | |
| class StructuredAttentionMask: | |
| """Symbolic attention rules for sublinear attention. | |
| Dense masks are O(T^2). This object carries the rule so sublinear attention can | |
| apply it only to the gathered local/anchor candidate keys: O(T * candidates). | |
| """ | |
| __slots__ = ("kind", "q_len", "k_len", "query_base", "block") | |
| def __init__(self, kind: str, q_len: int, k_len: int = None, query_base: int = 0, block: int = 1): | |
| self.kind = (kind or "none").lower() | |
| self.q_len = int(q_len) | |
| self.k_len = int(k_len if k_len is not None else q_len) | |
| self.query_base = int(query_base) | |
| self.block = max(1, int(block)) | |
| def to_dense(self, device=None, dtype=torch.float32): | |
| device = device or DEV | |
| if self.kind in {"none", "nat", "bidirectional", "unrestricted"}: | |
| return None | |
| q_pos = torch.arange(self.query_base, self.query_base + self.q_len, device=device, dtype=torch.long).view(self.q_len, 1) | |
| k_pos = torch.arange(self.k_len, device=device, dtype=torch.long).view(1, self.k_len) | |
| if self.kind == "causal": | |
| allow = k_pos <= q_pos | |
| elif self.kind in {"sat", "block_causal", "block-causal"}: | |
| allow = (k_pos // self.block) <= (q_pos // self.block) | |
| else: | |
| raise ValueError(f"unknown structured attention mask kind: {self.kind}") | |
| zeros = torch.zeros((self.q_len, self.k_len), device=device, dtype=dtype) | |
| neg = torch.full_like(zeros, float("-inf")) | |
| return torch.where(allow, zeros, neg).unsqueeze(0).unsqueeze(0) | |
| def _is_structured_attention_mask(mask) -> bool: | |
| return isinstance(mask, StructuredAttentionMask) | |
| def use_structured_masks(args=None, backend: str = None) -> bool: | |
| backend = (backend or getattr(args, "attn_backend", "") or "").lower() | |
| return backend == "sublinear" and not bool(getattr(args, "no_structured_masks", False)) | |
| # ───────────────────────── Model components ───────────────────────── | |
| class KVBuffer: | |
| """Preallocated K/V cache for decode. Replaces torch.cat-based growth. | |
| Layout matches MHA-internal head-major shape [B, H, T, d_k]. Caller sizes | |
| once; each ``append`` writes ``length:length+n`` slots in place and grows | |
| ``length``. ``view()`` returns slices of the live region so attention sees | |
| only filled positions. | |
| """ | |
| __slots__ = ("k", "v", "length", "capacity") | |
| def __init__( | |
| self, | |
| batch: int, | |
| heads: int, | |
| capacity: int, | |
| d_k: int, | |
| device, | |
| dtype, | |
| ): | |
| self.k = torch.empty(batch, heads, capacity, d_k, device=device, dtype=dtype) | |
| self.v = torch.empty(batch, heads, capacity, d_k, device=device, dtype=dtype) | |
| self.length = 0 | |
| self.capacity = capacity | |
| def append(self, k_new: torch.Tensor, v_new: torch.Tensor): | |
| n = k_new.size(2) | |
| end = self.length + n | |
| if end > self.capacity: | |
| raise RuntimeError( | |
| f"KVBuffer overflow: length={self.length} + n={n} > capacity={self.capacity}" | |
| ) | |
| self.k[:, :, self.length:end].copy_(k_new) | |
| self.v[:, :, self.length:end].copy_(v_new) | |
| self.length = end | |
| def view(self): | |
| return self.k[:, :, :self.length], self.v[:, :, :self.length] | |
| class TuneableAttentionMHA(nn.Module): | |
| def __init__( | |
| self, | |
| d: int, | |
| h: int, | |
| r: int, | |
| use_relpos: bool = True, | |
| attn_backend: str = DEFAULT_ATTN_BACKEND, | |
| sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW, | |
| sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE, | |
| sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS, | |
| sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK, | |
| sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS, | |
| sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS, | |
| sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS, | |
| tie_kv: bool = False, | |
| ): | |
| super().__init__() | |
| assert d % h == 0 | |
| self.h, self.dk, self.r = h, d // h, r | |
| self.use_relpos = use_relpos | |
| self.attn_backend = (attn_backend or "manual").lower() | |
| self.sublinear_window = max(1, int(sublinear_window)) | |
| self.sublinear_stride = max(0, int(sublinear_stride)) | |
| self.sublinear_max_anchors = max(0, int(sublinear_max_anchors)) | |
| self.sublinear_chunk = max(1, int(sublinear_chunk)) | |
| self.sublinear_sinks = max(0, int(sublinear_sinks)) | |
| recent = int(sublinear_recent_anchors) | |
| if recent < 0: | |
| recent = self.sublinear_max_anchors // 2 | |
| self.sublinear_recent_anchors = min(max(0, recent), self.sublinear_max_anchors) | |
| self.sublinear_pooled_landmarks = bool(sublinear_pooled_landmarks) | |
| # Exact n1 harvest: one fused QKV projection is mathematically the same | |
| # as three independent bias-free Linear(d, d) projections with their | |
| # weights stacked along out_features. | |
| # Q-K=V (arXiv 2606.04032): tie Key & Value into one shared projection. | |
| # For r>dk, reshape_heads==reshape_v so k_new IS v_new (exact) -> clean 50% KV-cache cut | |
| # and -33% qkv params. Gated; default off preserves the 3*d checkpoint layout. | |
| self.tie_kv = bool(tie_kv) | |
| self.qkv = nn.Linear(d, (2 if self.tie_kv else 3) * d, bias=False) | |
| self.U = nn.Parameter(torch.randn(self.dk, r)) | |
| nn.init.orthogonal_(self.U) | |
| self.proj = nn.Linear(h * self.dk, d, bias=False) | |
| self.drop = nn.Dropout(0.1) | |
| # Exact n1 harvest: for expansion ranks, (q @ U) @ (k @ U).T is | |
| # q @ (U @ U.T) @ k.T. This keeps score/cache width at d_k with no | |
| # quality change. Inference caches the metric and training recomputes | |
| # it so gradients through U are unchanged. | |
| self._metric_cache: Optional[torch.Tensor] = None | |
| self._metric_cache_ver: int = -1 | |
| self._metric_cache_param_id: int = -1 | |
| self._metric_cache_data_ptr: int = -1 | |
| self._metric_cache_shape: Tuple[int, int] = (-1, -1) | |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |
| missing_keys, unexpected_keys, error_msgs): | |
| qkv_key = prefix + "qkv.weight" | |
| if qkv_key not in state_dict: | |
| qk = prefix + "q.weight" | |
| kk = prefix + "k.weight" | |
| vk = prefix + "v.weight" | |
| if qk in state_dict and kk in state_dict and vk in state_dict: | |
| fused = _cat_legacy_weight_blocks([state_dict[qk], state_dict[kk], state_dict[vk]]) | |
| if fused is not None: | |
| state_dict[qkv_key] = fused | |
| state_dict.pop(qk) | |
| state_dict.pop(kk) | |
| state_dict.pop(vk) | |
| return super()._load_from_state_dict( | |
| state_dict, prefix, local_metadata, strict, | |
| missing_keys, unexpected_keys, error_msgs, | |
| ) | |
| def _proj_qk(self, x): | |
| B, N, _ = x.shape | |
| return (x.view(B, N, self.h, self.dk).transpose(1, 2) @ self.U) | |
| def _reshape_v(self, x): | |
| B, N, _ = x.shape | |
| return x.view(B, N, self.h, self.dk).transpose(1, 2) | |
| def _reshape_heads(self, x): | |
| B, N, _ = x.shape | |
| return x.view(B, N, self.h, self.dk).transpose(1, 2) | |
| def _get_metric(self) -> torch.Tensor: | |
| if torch.is_grad_enabled(): | |
| return self.U @ self.U.T | |
| cur_ver = self.U._version | |
| cur_param_id = id(self.U) | |
| cur_data_ptr = int(self.U.data_ptr()) | |
| cur_shape = tuple(self.U.shape) | |
| cache = self._metric_cache | |
| if ( | |
| cache is None | |
| or cache.dtype != self.U.dtype | |
| or cache.device != self.U.device | |
| or self._metric_cache_ver != cur_ver | |
| or self._metric_cache_param_id != cur_param_id | |
| or self._metric_cache_data_ptr != cur_data_ptr | |
| or self._metric_cache_shape != cur_shape | |
| ): | |
| cache = (self.U @ self.U.T).detach() | |
| self._metric_cache = cache | |
| self._metric_cache_ver = cur_ver | |
| self._metric_cache_param_id = cur_param_id | |
| self._metric_cache_data_ptr = cur_data_ptr | |
| self._metric_cache_shape = cur_shape | |
| return cache | |
| def train(self, mode: bool = True): | |
| if mode: | |
| self._metric_cache = None | |
| self._metric_cache_ver = -1 | |
| self._metric_cache_param_id = -1 | |
| self._metric_cache_data_ptr = -1 | |
| self._metric_cache_shape = (-1, -1) | |
| return super().train(mode) | |
| def _structured_valid(self, attn_mask, q_pos, idx): | |
| if not _is_structured_attention_mask(attn_mask): | |
| return None | |
| kind = attn_mask.kind | |
| if kind in {"none", "nat", "bidirectional", "unrestricted"}: | |
| return torch.ones_like(idx, dtype=torch.bool) | |
| if kind == "causal": | |
| return idx <= q_pos[:, None] | |
| if kind in {"sat", "block_causal", "block-causal"}: | |
| block = max(1, int(attn_mask.block)) | |
| return (idx // block) <= (q_pos[:, None] // block) | |
| raise ValueError(f"unknown structured attention mask kind: {kind}") | |
| def _sublinear_anchor_positions(self, k_len: int, device): | |
| anchor_start = self.sublinear_stride - 1 | |
| if self.sublinear_stride <= 0 or self.sublinear_max_anchors <= 0 or anchor_start >= k_len: | |
| anchors = torch.empty(0, device=device, dtype=torch.long) | |
| else: | |
| all_anchors = torch.arange(anchor_start, k_len, self.sublinear_stride, device=device, dtype=torch.long) | |
| if all_anchors.numel() <= self.sublinear_max_anchors: | |
| anchors = all_anchors | |
| else: | |
| recent_budget = min(self.sublinear_recent_anchors, self.sublinear_max_anchors) | |
| span_budget = max(0, self.sublinear_max_anchors - recent_budget) | |
| parts = [] | |
| if span_budget > 0: | |
| span_sel = torch.linspace(0, all_anchors.numel() - 1, span_budget, device=device).round().long().unique() | |
| parts.append(all_anchors[span_sel]) | |
| if recent_budget > 0: | |
| parts.append(all_anchors[-recent_budget:]) | |
| anchors = torch.cat(parts).unique() if parts else torch.empty(0, device=device, dtype=torch.long) | |
| if self.sublinear_sinks > 0 and k_len > 0: | |
| sinks = torch.arange(min(self.sublinear_sinks, k_len), device=device, dtype=torch.long) | |
| anchors = torch.cat([sinks, anchors]).unique() if anchors.numel() else sinks | |
| return anchors | |
| def _sublinear_attention(self, q, k, v, attn_mask=None, rel_bias_tokens=None): | |
| """Local-window + landmark attention: O(N * (window + N/stride)).""" | |
| bsz, heads, q_len, _ = q.shape | |
| k_len = k.size(2) | |
| device = q.device | |
| query_base = max(0, k_len - q_len) | |
| outputs = [] | |
| scale = 1.0 / math.sqrt(self.dk) | |
| slopes = None | |
| if self.use_relpos and rel_bias_tokens is not None: | |
| slopes = _alibi_slopes(self.h).to(device=device, dtype=torch.float32) | |
| anchors = self._sublinear_anchor_positions(k_len, device) | |
| anchor_k = anchor_v = None | |
| if anchors.numel() and self.sublinear_pooled_landmarks and self.sublinear_stride > 1: | |
| # Optional pooled landmarks: each global anchor summarizes its stride segment. | |
| # This is off by default because it adds cumsum work; enable after benchmarking. | |
| ends = anchors + 1 | |
| starts = (ends - self.sublinear_stride).clamp_min(0) | |
| zero_k = k.new_zeros(k.size(0), k.size(1), 1, k.size(3)) | |
| zero_v = v.new_zeros(v.size(0), v.size(1), 1, v.size(3)) | |
| prefix_k = torch.cat([zero_k, k.cumsum(dim=2)], dim=2) | |
| prefix_v = torch.cat([zero_v, v.cumsum(dim=2)], dim=2) | |
| denom = (ends - starts).to(dtype=k.dtype).view(1, 1, -1, 1).clamp_min(1) | |
| anchor_k = (prefix_k[:, :, ends, :] - prefix_k[:, :, starts, :]) / denom | |
| anchor_v = (prefix_v[:, :, ends, :] - prefix_v[:, :, starts, :]) / denom | |
| offsets = torch.arange( | |
| -self.sublinear_window, | |
| self.sublinear_window + 1, | |
| device=device, | |
| dtype=torch.long, | |
| ) | |
| for q_start in range(0, q_len, self.sublinear_chunk): | |
| q_end = min(q_len, q_start + self.sublinear_chunk) | |
| cur = q_end - q_start | |
| q_pos = torch.arange(query_base + q_start, query_base + q_end, device=device, dtype=torch.long) | |
| local_raw = q_pos[:, None] + offsets[None, :] | |
| local_valid = (local_raw >= 0) & (local_raw < k_len) | |
| local_idx = local_raw.clamp(0, max(0, k_len - 1)) | |
| k_local = k[:, :, local_idx, :] | |
| v_local = v[:, :, local_idx, :] | |
| if anchors.numel(): | |
| anchor_idx = anchors.view(1, -1).expand(cur, -1) | |
| local_lo = (q_pos - self.sublinear_window).clamp_min(0).view(-1, 1) | |
| local_hi = (q_pos + self.sublinear_window).clamp_max(max(0, k_len - 1)).view(-1, 1) | |
| # Drop anchor copies already present in the local window; duplicates bias softmax mass. | |
| anchor_valid = (anchor_idx < local_lo) | (anchor_idx > local_hi) | |
| idx = torch.cat([local_idx, anchor_idx], dim=1) | |
| valid = torch.cat([local_valid, anchor_valid], dim=1) | |
| if anchor_k is not None and anchor_v is not None: | |
| k_anchor = anchor_k.unsqueeze(2).expand(-1, -1, cur, -1, -1) | |
| v_anchor = anchor_v.unsqueeze(2).expand(-1, -1, cur, -1, -1) | |
| else: | |
| k_anchor = k[:, :, anchor_idx, :] | |
| v_anchor = v[:, :, anchor_idx, :] | |
| k_sel = torch.cat([k_local, k_anchor], dim=-2) | |
| v_sel = torch.cat([v_local, v_anchor], dim=-2) | |
| else: | |
| idx = local_idx | |
| valid = local_valid | |
| k_sel = k_local | |
| v_sel = v_local | |
| structured_valid = self._structured_valid(attn_mask, q_pos, idx) | |
| if structured_valid is not None: | |
| valid = valid & structured_valid | |
| scores = (q[:, :, q_start:q_end, :].unsqueeze(-2) * k_sel).sum(dim=-1) * scale | |
| if slopes is not None: | |
| dist = (q_pos.view(1, 1, cur, 1) - idx.view(1, 1, cur, -1)).abs().to(torch.float32) | |
| scores = scores + (-slopes * dist).to(scores.dtype) | |
| if torch.is_tensor(attn_mask) and attn_mask.size(-1) == k_len and attn_mask.size(-2) >= q_end: | |
| mask_q = attn_mask[..., q_start:q_end, :] | |
| gather_idx = idx.view(1, 1, cur, -1).expand(mask_q.size(0), mask_q.size(1), cur, idx.size(1)) | |
| scores = scores + torch.gather(mask_q, -1, gather_idx) | |
| scores = scores.masked_fill(~valid.view(1, 1, cur, -1), float("-inf")) | |
| weights = torch.softmax(scores.float(), dim=-1).to(v.dtype) | |
| outputs.append((weights.unsqueeze(-1) * v_sel).sum(dim=-2)) | |
| return torch.cat(outputs, dim=2) | |
| def forward(self, x, mask=None, rel_bias_tokens=None, kv_cache=None, use_cache=False): | |
| if self.tie_kv: | |
| q_lin, kv_lin = self.qkv(x).chunk(2, dim=-1) | |
| k_lin = v_lin = kv_lin | |
| else: | |
| q_lin, k_lin, v_lin = self.qkv(x).chunk(3, dim=-1) | |
| if self.r > self.dk: | |
| q = self._reshape_heads(q_lin) @ self._get_metric() | |
| k_new = self._reshape_heads(k_lin) | |
| v_new = k_new if self.tie_kv else self._reshape_v(v_lin) | |
| else: | |
| q = self._proj_qk(q_lin) | |
| k_new = self._proj_qk(k_lin) | |
| v_new = self._reshape_v(v_lin) | |
| if kv_cache is None: | |
| k, v = k_new, v_new | |
| elif isinstance(kv_cache, KVBuffer): | |
| if use_cache: | |
| kv_cache.append(k_new, v_new) | |
| k, v = kv_cache.view() | |
| else: | |
| k, v = k_new, v_new | |
| else: | |
| k_cached, v_cached = kv_cache | |
| if use_cache: | |
| k = torch.cat([k_cached, k_new], dim=2) | |
| v = torch.cat([v_cached, v_new], dim=2) | |
| else: | |
| k, v = k_new, v_new | |
| attn_mask = mask | |
| if self.attn_backend != "sublinear" and _is_structured_attention_mask(attn_mask): | |
| attn_mask = attn_mask.to_dense(device=q.device, dtype=q.dtype) | |
| if self.attn_backend != "sublinear" and self.use_relpos and rel_bias_tokens is not None: | |
| rel = alibi_bias(self.h, rel_bias_tokens)[:, :, -q.size(2):, :].to(device=q.device, dtype=q.dtype) | |
| attn_mask = rel if attn_mask is None else attn_mask + rel | |
| if self.attn_backend == "sdpa" and attn_mask is not None and attn_mask.dtype != torch.bool and attn_mask.dtype != q.dtype: | |
| attn_mask = attn_mask.to(dtype=q.dtype) | |
| if self.attn_backend == "sdpa": | |
| try: | |
| z = F.scaled_dot_product_attention( | |
| q, k, v, | |
| attn_mask=attn_mask, | |
| dropout_p=0.0, | |
| scale=1.0 / math.sqrt(self.dk), | |
| ) | |
| except TypeError: | |
| # Older torch lacks the scale kwarg. Rescale q so SDPA's default sqrt(r) | |
| # denominator matches the historical AGILLM sqrt(d_k) denominator. | |
| q_scaled = q * math.sqrt(q.size(-1) / self.dk) | |
| z = F.scaled_dot_product_attention(q_scaled, k, v, attn_mask=attn_mask, dropout_p=0.0) | |
| elif self.attn_backend == "sublinear": | |
| z = self._sublinear_attention(q, k, v, attn_mask=attn_mask, rel_bias_tokens=rel_bias_tokens) | |
| else: | |
| att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) | |
| if attn_mask is not None: | |
| att = att + attn_mask | |
| z = att.softmax(-1).to(v.dtype) @ v | |
| z = z.transpose(1, 2).reshape(x.size(0), x.size(1), -1) | |
| out = self.drop(self.proj(z)) | |
| if not use_cache: | |
| return out | |
| new_kv = kv_cache if isinstance(kv_cache, KVBuffer) else (k, v) | |
| return out, new_kv | |
| class MoEFFN(nn.Module): | |
| def __init__(self, d: int, mlp_mult: int = 4, experts: int = 4, top_k: int = 1, | |
| shared_experts: int = 0, shared_mlp_mult: int = 0): | |
| super().__init__() | |
| self.d = int(d) | |
| self.mlp_mult = max(1, int(mlp_mult)) | |
| self.num_experts = max(1, int(experts)) | |
| self.top_k = min(max(1, int(top_k)), self.num_experts) | |
| hidden = self.mlp_mult * self.d | |
| self.router = nn.Linear(self.d, self.num_experts, bias=False) | |
| self.experts = nn.ModuleList([ | |
| nn.Sequential(nn.Linear(self.d, hidden), nn.ReLU(), nn.Linear(hidden, self.d)) | |
| for _ in range(self.num_experts) | |
| ]) | |
| # Shared experts (DeepSeek/ST-MoE style): always-on FFN added to the routed | |
| # output, giving every token a consistent fallback representation -> lower | |
| # routing variance, smoother optimization. Output layer is ZERO-INITIALISED so | |
| # the shared path is a no-op at step 0, making it mergeable into an existing | |
| # checkpoint without disruption (it then learns to contribute). | |
| self.num_shared = max(0, int(shared_experts)) | |
| if self.num_shared > 0: | |
| shidden = max(1, int(shared_mlp_mult) or self.mlp_mult) * self.d | |
| self.shared = nn.ModuleList([ | |
| nn.Sequential(nn.Linear(self.d, shidden), nn.ReLU(), nn.Linear(shidden, self.d)) | |
| for _ in range(self.num_shared) | |
| ]) | |
| for blk in self.shared: | |
| nn.init.zeros_(blk[2].weight); nn.init.zeros_(blk[2].bias) | |
| else: | |
| self.shared = None | |
| # Detached FFN input stashed each training forward; the router aux loss is | |
| # recomputed OUTSIDE the gradient-checkpoint boundary by _collect_moe_aux(). | |
| self.last_router_input = None | |
| # Inference-only expert streaming: block-stream can keep only router/shared | |
| # paths resident and page selected routed experts on demand. | |
| self.expert_stream = False | |
| self.expert_stream_empty_cache = True | |
| self.expert_stream_stats = {"loads": 0, "tokens": 0} | |
| def set_expert_stream(self, enabled: bool, empty_cache: bool = True): | |
| self.expert_stream = bool(enabled) | |
| self.expert_stream_empty_cache = bool(empty_cache) | |
| return self | |
| def _run_expert(self, expert, rows): | |
| if self.expert_stream and torch.is_tensor(rows) and rows.is_cuda: | |
| expert.to(rows.device) | |
| try: | |
| out = expert(rows) | |
| finally: | |
| expert.to("cpu") | |
| self.expert_stream_stats["loads"] = int(self.expert_stream_stats.get("loads", 0)) + 1 | |
| self.expert_stream_stats["tokens"] = int(self.expert_stream_stats.get("tokens", 0)) + int(rows.size(0)) | |
| if self.expert_stream_empty_cache and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return out | |
| return expert(rows) | |
| def _shared_out(self, flat): | |
| if self.shared is None: | |
| return 0.0 | |
| s = self.shared[0](flat) | |
| for blk in self.shared[1:]: | |
| s = s + blk(flat) | |
| return s | |
| def forward(self, x): | |
| orig_shape = x.shape | |
| flat = x.reshape(-1, orig_shape[-1]) | |
| if self.training: | |
| # Stash the detached input (no autograd graph) so the load-balance loss | |
| # can be recomputed after the block forward. Computing it here would run | |
| # without grad (checkpoint's no-grad first pass) or pin block activations | |
| # across the checkpoint boundary and blow up VRAM. | |
| self.last_router_input = flat.detach() | |
| router_in = flat.to(self.router.weight.dtype) if flat.dtype != self.router.weight.dtype else flat | |
| scores = self.router(router_in).float() | |
| if self.top_k == 1: | |
| probs = scores.softmax(dim=-1) | |
| chosen = probs.argmax(dim=-1) | |
| out = torch.zeros_like(flat) | |
| for expert_id, expert in enumerate(self.experts): | |
| mask = chosen == expert_id | |
| if not bool(mask.any()): | |
| continue | |
| gate = probs[mask, expert_id].to(flat.dtype).clamp_min(1e-6) | |
| # Keep the forward value equal to the selected expert while | |
| # sending a straight-through gradient into the top-1 router. | |
| gate_st = (gate / gate.detach()).unsqueeze(-1) | |
| out[mask] = self._run_expert(expert, flat[mask]) * gate_st | |
| if self.shared is not None: | |
| out = out + self._shared_out(flat) | |
| return out.reshape(orig_shape) | |
| vals, idx = torch.topk(scores, k=self.top_k, dim=-1) | |
| weights = vals.softmax(dim=-1).to(flat.dtype) | |
| out = torch.zeros_like(flat) | |
| for rank in range(self.top_k): | |
| chosen = idx[:, rank] | |
| weight = weights[:, rank].unsqueeze(-1) | |
| for expert_id, expert in enumerate(self.experts): | |
| rows = (chosen == expert_id).nonzero(as_tuple=False).flatten() | |
| if rows.numel() == 0: | |
| continue | |
| out.index_add_(0, rows, self._run_expert(expert, flat.index_select(0, rows)) * weight.index_select(0, rows)) | |
| if self.shared is not None: | |
| out = out + self._shared_out(flat) | |
| return out.reshape(orig_shape) | |
| def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, | |
| missing_keys, unexpected_keys, error_msgs): | |
| legacy = { | |
| "0.weight": "0.weight", | |
| "0.bias": "0.bias", | |
| "2.weight": "2.weight", | |
| "2.bias": "2.bias", | |
| } | |
| seeded = False | |
| for expert_idx, expert in enumerate(self.experts): | |
| expert_state = expert.state_dict() | |
| for legacy_suffix, expert_suffix in legacy.items(): | |
| src_key = prefix + legacy_suffix | |
| dst_key = prefix + f"experts.{expert_idx}." + expert_suffix | |
| src = state_dict.get(src_key) | |
| tgt = expert_state.get(expert_suffix) | |
| if dst_key not in state_dict and torch.is_tensor(src) and torch.is_tensor(tgt) and tuple(src.shape) == tuple(tgt.shape): | |
| state_dict[dst_key] = src | |
| seeded = True | |
| if seeded and prefix + "router.weight" not in state_dict: | |
| state_dict[prefix + "router.weight"] = self.router.weight.detach().clone() | |
| if seeded: | |
| for suffix in legacy: | |
| state_dict.pop(prefix + suffix, None) | |
| return super()._load_from_state_dict( | |
| state_dict, prefix, local_metadata, strict, | |
| missing_keys, unexpected_keys, error_msgs, | |
| ) | |
| def _collect_moe_aux(model, aux_coef=0.0, z_coef=0.0): | |
| """Sum and clear the MoE load-balance / router-z losses. | |
| Recomputes the router on the detached FFN input stashed during the forward, | |
| so it works with gradient checkpointing (router logits are available WITH grad | |
| here, outside the checkpointed region) and pins no block activations (the input | |
| is detached, so only router.weight receives gradient). Returns a scalar tensor | |
| to add to the loss before backward(), or 0.0 when disabled / nothing stashed. | |
| Verified on a 4090 (28L/d1280, AMP+grad_checkpoint): peak VRAM delta ~1MB. | |
| """ | |
| total = None | |
| for m in model.modules(): | |
| if isinstance(m, MoEFFN): | |
| inp = m.last_router_input | |
| m.last_router_input = None | |
| if inp is None or (aux_coef <= 0 and z_coef <= 0): | |
| continue | |
| router_in = inp.to(m.router.weight.dtype) if inp.dtype != m.router.weight.dtype else inp | |
| scores = m.router(router_in).float() | |
| probs = scores.softmax(dim=-1) | |
| importance = probs.mean(dim=0) | |
| top1 = probs.argmax(dim=-1) | |
| load = torch.bincount(top1, minlength=m.num_experts).to(importance.dtype) / max(1, top1.numel()) | |
| if aux_coef > 0: | |
| lb = aux_coef * m.num_experts * (load.detach() * importance).sum() | |
| total = lb if total is None else total + lb | |
| if z_coef > 0: | |
| zl = z_coef * (torch.logsumexp(scores, dim=-1) ** 2).mean() | |
| total = zl if total is None else total + zl | |
| return total if total is not None else 0.0 | |
| class Block(nn.Module): | |
| def __init__( | |
| self, | |
| d: int, | |
| h: int, | |
| r: int, | |
| attn_backend: str = DEFAULT_ATTN_BACKEND, | |
| sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW, | |
| sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE, | |
| sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS, | |
| sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK, | |
| sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS, | |
| sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS, | |
| sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS, | |
| moe_ffn: bool = DEFAULT_MOE_FFN, | |
| moe_experts: int = DEFAULT_MOE_EXPERTS, | |
| moe_top_k: int = DEFAULT_MOE_TOP_K, | |
| moe_mlp_mult: int = DEFAULT_MOE_MLP_MULT, | |
| moe_shared_experts: int = 0, | |
| moe_shared_mlp_mult: int = 0, | |
| tie_kv: bool = False, | |
| ): | |
| super().__init__() | |
| self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d) | |
| self.mha = TuneableAttentionMHA( | |
| d, | |
| h, | |
| r, | |
| attn_backend=attn_backend, | |
| sublinear_window=sublinear_window, | |
| sublinear_stride=sublinear_stride, | |
| sublinear_max_anchors=sublinear_max_anchors, | |
| sublinear_chunk=sublinear_chunk, | |
| sublinear_sinks=sublinear_sinks, | |
| sublinear_recent_anchors=sublinear_recent_anchors, | |
| sublinear_pooled_landmarks=sublinear_pooled_landmarks, | |
| tie_kv=tie_kv, | |
| ) | |
| self.ff = ( | |
| MoEFFN(d, mlp_mult=moe_mlp_mult, experts=moe_experts, top_k=moe_top_k, | |
| shared_experts=moe_shared_experts, shared_mlp_mult=moe_shared_mlp_mult) | |
| if moe_ffn else nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d)) | |
| ) | |
| def forward(self, x, mask, kv=None, use_cache=False, total_seq_len=None): | |
| if use_cache: | |
| y, new_kv = self.mha(self.ln1(x), mask, rel_bias_tokens=total_seq_len, kv_cache=kv, use_cache=True) | |
| x = x + y + self.ff(self.ln2(x + y)) | |
| return x, new_kv | |
| else: | |
| n = x.size(1) | |
| x = x + self.mha(self.ln1(x), mask, rel_bias_tokens=n) | |
| return x + self.ff(self.ln2(x)) | |
| class Encoder(nn.Module): | |
| def __init__( | |
| self, | |
| cfg, | |
| tie_weights: bool = False, | |
| attn_backend: str = DEFAULT_ATTN_BACKEND, | |
| grad_checkpoint: bool = False, | |
| sublinear_window: int = DEFAULT_SUBLINEAR_WINDOW, | |
| sublinear_stride: int = DEFAULT_SUBLINEAR_STRIDE, | |
| sublinear_max_anchors: int = DEFAULT_SUBLINEAR_MAX_ANCHORS, | |
| sublinear_chunk: int = DEFAULT_SUBLINEAR_CHUNK, | |
| sublinear_sinks: int = DEFAULT_SUBLINEAR_SINKS, | |
| sublinear_recent_anchors: int = DEFAULT_SUBLINEAR_RECENT_ANCHORS, | |
| sublinear_pooled_landmarks: bool = DEFAULT_SUBLINEAR_POOLED_LANDMARKS, | |
| anchor_memory: bool = DEFAULT_ANCHOR_MEMORY, | |
| anchor_stride: int = DEFAULT_ANCHOR_STRIDE, | |
| anchor_max: int = DEFAULT_ANCHOR_MAX, | |
| anchor_position: int = DEFAULT_ANCHOR_POSITION, | |
| moe_ffn: Optional[bool] = None, | |
| moe_experts: Optional[int] = None, | |
| moe_top_k: Optional[int] = None, | |
| moe_mlp_mult: Optional[int] = None, | |
| moe_shared_experts: Optional[int] = None, | |
| moe_shared_mlp_mult: Optional[int] = None, | |
| tie_kv: Optional[bool] = None, | |
| ): | |
| super().__init__() | |
| d, l, h, r = cfg["d"], cfg["layers"], cfg["heads"], cfg["rank"] | |
| if tie_kv is None: | |
| tie_kv = bool(cfg.get("tie_kv", False)) | |
| if moe_ffn is None: | |
| moe_ffn = bool(cfg.get("moe_ffn", DEFAULT_MOE_FFN)) | |
| if moe_experts is None: | |
| moe_experts = int(cfg.get("moe_experts", DEFAULT_MOE_EXPERTS)) | |
| if moe_top_k is None: | |
| moe_top_k = int(cfg.get("moe_top_k", DEFAULT_MOE_TOP_K)) | |
| if moe_mlp_mult is None: | |
| moe_mlp_mult = int(cfg.get("moe_mlp_mult", DEFAULT_MOE_MLP_MULT)) | |
| moe_experts = max(1, int(moe_experts)) | |
| moe_top_k = min(max(1, int(moe_top_k)), moe_experts) | |
| moe_mlp_mult = max(1, int(moe_mlp_mult)) | |
| if moe_shared_experts is None: | |
| moe_shared_experts = int(cfg.get("moe_shared_experts", 0)) | |
| if moe_shared_mlp_mult is None: | |
| moe_shared_mlp_mult = int(cfg.get("moe_shared_mlp_mult", 0)) | |
| moe_shared_experts = max(0, int(moe_shared_experts)) | |
| self.emb = nn.Embedding(VOCAB, d) | |
| self.blocks = nn.ModuleList([ | |
| Block( | |
| d, | |
| h, | |
| r, | |
| attn_backend=attn_backend, | |
| sublinear_window=sublinear_window, | |
| sublinear_stride=sublinear_stride, | |
| sublinear_max_anchors=sublinear_max_anchors, | |
| sublinear_chunk=sublinear_chunk, | |
| sublinear_sinks=sublinear_sinks, | |
| sublinear_recent_anchors=sublinear_recent_anchors, | |
| sublinear_pooled_landmarks=sublinear_pooled_landmarks, | |
| moe_ffn=bool(moe_ffn), | |
| moe_experts=moe_experts, | |
| moe_top_k=moe_top_k, | |
| moe_mlp_mult=moe_mlp_mult, | |
| moe_shared_experts=moe_shared_experts, | |
| moe_shared_mlp_mult=moe_shared_mlp_mult, | |
| tie_kv=bool(tie_kv), | |
| ) | |
| for _ in range(l) | |
| ]) | |
| self.ln = nn.LayerNorm(d) | |
| self.tie_weights = tie_weights | |
| self.attn_backend = attn_backend | |
| self.grad_checkpoint = grad_checkpoint | |
| self.sublinear_window = sublinear_window | |
| self.sublinear_stride = sublinear_stride | |
| self.sublinear_max_anchors = sublinear_max_anchors | |
| self.sublinear_chunk = sublinear_chunk | |
| self.sublinear_sinks = sublinear_sinks | |
| self.sublinear_recent_anchors = sublinear_recent_anchors | |
| self.sublinear_pooled_landmarks = bool(sublinear_pooled_landmarks) | |
| self.moe_ffn = bool(moe_ffn) | |
| self.moe_experts = moe_experts | |
| self.moe_top_k = moe_top_k | |
| self.moe_mlp_mult = moe_mlp_mult | |
| self.moe_shared_experts = moe_shared_experts | |
| self.anchor_memory_enabled = bool(anchor_memory) | |
| self.anchor_stride = int(anchor_stride) | |
| self.anchor_max = int(anchor_max) | |
| n_layers = int(cfg["layers"]) | |
| if int(anchor_position) < 0: | |
| self.anchor_position = n_layers // 2 | |
| else: | |
| self.anchor_position = min(int(anchor_position), n_layers - 1) | |
| if self.anchor_memory_enabled: | |
| am_cfg = AnchorMemoryConfig( | |
| d_model=int(cfg["d"]), | |
| heads=int(cfg["heads"]), | |
| anchor_stride=self.anchor_stride, | |
| max_anchors=self.anchor_max, | |
| ) | |
| self.anchor = AnchorMemoryLayer(am_cfg) | |
| else: | |
| self.anchor = None | |
| def forward(self, ids, mask, kv_caches=None, use_cache=False, total_seq_len=None): | |
| x = self.emb(ids) | |
| if not use_cache: | |
| for i, blk in enumerate(self.blocks): | |
| if self.grad_checkpoint and self.training: | |
| x = torch_checkpoint.checkpoint(lambda y, block=blk: block(y, mask), x, use_reentrant=False) | |
| else: | |
| x = blk(x, mask) | |
| if self.anchor is not None and i == self.anchor_position: | |
| if self.grad_checkpoint and self.training: | |
| x, _ = torch_checkpoint.checkpoint(self.anchor, x, use_reentrant=False) | |
| else: | |
| x, _ = self.anchor(x) | |
| return self.ln(x) | |
| new_kvs = [] | |
| for i, blk in enumerate(self.blocks): | |
| kv = kv_caches[i] if kv_caches else None | |
| x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len) | |
| new_kvs.append(kv_out) | |
| if self.anchor is not None and i == self.anchor_position: | |
| x, _ = self.anchor(x) | |
| return self.ln(x), new_kvs | |
| class ARHead(nn.Module): | |
| def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None): | |
| super().__init__() | |
| self.tie_weights = tie_weights | |
| if tie_weights and embedding_weight is not None: | |
| self.proj = nn.Linear(d, VOCAB, bias=False) | |
| self.proj.weight = embedding_weight | |
| else: | |
| self.proj = nn.Linear(d, VOCAB) | |
| def forward(self, h): | |
| return self.proj(h) | |
| class NATHead(nn.Module): | |
| def __init__(self, d, tie_weights: bool = False, embedding_weight: nn.Parameter = None): | |
| super().__init__() | |
| self.tie_weights = tie_weights | |
| if tie_weights and embedding_weight is not None: | |
| self.proj = nn.Linear(d, VOCAB, bias=False) | |
| self.proj.weight = embedding_weight | |
| else: | |
| self.proj = nn.Linear(d, VOCAB) | |
| def forward(self, h): | |
| return self.proj(h) | |
| class SATHead(nn.Module): | |
| def __init__(self, d, mode="var", tie_weights: bool = False, embedding_weight: nn.Parameter = None, mlp: bool = False): | |
| super().__init__() | |
| self.tie_weights = tie_weights | |
| self.mlp = bool(mlp) | |
| if self.mlp: | |
| self.proj = nn.Sequential( | |
| nn.Linear(d, d), | |
| nn.GELU(), | |
| nn.Linear(d, VOCAB), | |
| ) | |
| elif tie_weights and embedding_weight is not None: | |
| self.proj = nn.Linear(d, VOCAB, bias=False) | |
| self.proj.weight = embedding_weight | |
| else: | |
| self.proj = nn.Linear(d, VOCAB) | |
| self.gate = nn.Linear(d, 2) if mode == "var" else None | |
| def forward(self, h_last): | |
| return self.proj(h_last), (self.gate(h_last[:, 0]) if self.gate else None) | |
| # ───────────────────────── Masks ───────────────────────── | |
| def causal_mask(n, structured: bool = False): | |
| if structured: | |
| return StructuredAttentionMask("causal", q_len=n, k_len=n, query_base=0) | |
| return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1) | |
| def sat_mask(n, block=SAT_BLOCK, structured: bool = False): | |
| if structured: | |
| return StructuredAttentionMask("sat", q_len=n, k_len=n, query_base=0, block=block) | |
| idx = torch.arange(n, device=DEV) | |
| grp = idx.unsqueeze(0) // block | |
| allow = (grp.T == grp) | (grp.T > grp) | |
| return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0) | |
| def sat_mask_cached(new_len: int, cached_len: int, block=SAT_BLOCK, structured: bool = False): | |
| total_len = cached_len + new_len | |
| if structured: | |
| return StructuredAttentionMask("sat", q_len=new_len, k_len=total_len, query_base=cached_len, block=block) | |
| q_idx = torch.arange(cached_len, total_len, device=DEV).unsqueeze(1) | |
| k_idx = torch.arange(total_len, device=DEV).unsqueeze(0) | |
| q_grp = q_idx // block | |
| k_grp = k_idx // block | |
| allow = q_grp >= k_grp | |
| return torch.where(allow, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0) | |
| # ───────────────────────── Checkpoint helpers ───────────────────────── | |
| # ───────────────────────── Delta Checkpoints (weight-only, async) ───────────────────────── | |
| _delta_lock = threading.Lock() | |
| _delta_thread: Optional[threading.Thread] = None | |
| def _sha256_file(path: pathlib.Path) -> str: | |
| """Compute SHA256 of a file for integrity verification.""" | |
| h = hashlib.sha256() | |
| with open(path, "rb") as f: | |
| for chunk in iter(lambda: f.read(1 << 20), b""): | |
| h.update(chunk) | |
| return h.hexdigest() | |
| def _do_delta_save(tensors: dict, path: pathlib.Path, meta: dict): | |
| """Background worker: write weight-only checkpoint + checksum.""" | |
| try: | |
| path.parent.mkdir(exist_ok=True, parents=True) | |
| tmp = path.with_suffix(path.suffix + ".dtmp") | |
| torch.save({"weights": tensors, **meta}, tmp, _use_new_zipfile_serialization=False) | |
| digest = _sha256_file(tmp) | |
| tmp.replace(path) | |
| # Write sidecar checksum | |
| path.with_suffix(".sha256").write_text(f"{digest} {path.name}\n") | |
| print(f" [delta] saved {path.name} ({digest[:12]}...)") | |
| except Exception as e: | |
| print(f" [delta] FAILED {path.name}: {e}") | |
| def _delete_delta_artifacts(path: pathlib.Path): | |
| for sidecar in ( | |
| path, | |
| path.with_suffix(".sha256"), | |
| path.with_suffix(path.suffix + ".upload.sha256"), | |
| path.with_suffix(path.suffix + ".dtmp"), | |
| ): | |
| try: | |
| if sidecar.exists(): | |
| sidecar.unlink() | |
| except Exception: | |
| pass | |
| def _unwrap_compiled_module(module: nn.Module) -> nn.Module: | |
| """Return the original module when torch.compile wrapped it.""" | |
| return getattr(module, "_orig_mod", module) | |
| def _checkpoint_state_dict(module: nn.Module) -> dict: | |
| """State dict with stable keys, even when module is torch.compile'd.""" | |
| return _unwrap_compiled_module(module).state_dict() | |
| def _strip_orig_mod_prefix(state: dict) -> dict: | |
| """Accept older deltas accidentally saved from compiled modules.""" | |
| if not isinstance(state, dict): | |
| return state | |
| prefix = "_orig_mod." | |
| if not any(isinstance(k, str) and k.startswith(prefix) for k in state): | |
| return state | |
| return { | |
| (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v | |
| for k, v in state.items() | |
| } | |
| def _cat_legacy_weight_blocks(blocks: list) -> Optional[torch.Tensor]: | |
| if not blocks or not all(torch.is_tensor(t) for t in blocks): | |
| return None | |
| first = blocks[0] | |
| tail_shape = tuple(first.shape[1:]) | |
| if any(t.dtype != first.dtype or t.device != first.device for t in blocks): | |
| return None | |
| if any(t.ndim != first.ndim or tuple(t.shape[1:]) != tail_shape for t in blocks): | |
| return None | |
| return torch.cat(blocks, dim=0).contiguous() | |
| def _fuse_qkv_in_state_dict(sd: dict) -> dict: | |
| """Fold legacy q/k/v.weight triples into qkv.weight before loading/filtering.""" | |
| if not isinstance(sd, dict): | |
| return sd | |
| prefixes = set() | |
| for key in list(sd.keys()): | |
| for suffix in (".q.weight", ".k.weight", ".v.weight"): | |
| if isinstance(key, str) and key.endswith(suffix): | |
| prefixes.add(key[: -len(suffix)]) | |
| for prefix in prefixes: | |
| qk, kk, vk = prefix + ".q.weight", prefix + ".k.weight", prefix + ".v.weight" | |
| fk = prefix + ".qkv.weight" | |
| if qk in sd and kk in sd and vk in sd and fk not in sd: | |
| fused = _cat_legacy_weight_blocks([sd[qk], sd[kk], sd[vk]]) | |
| if fused is not None: | |
| sd[fk] = fused | |
| sd.pop(qk) | |
| sd.pop(kk) | |
| sd.pop(vk) | |
| return sd | |
| def _expand_dense_ffn_to_moe_state_dict(sd: dict, target_sd: dict) -> dict: | |
| if not isinstance(sd, dict) or not isinstance(target_sd, dict): | |
| return sd | |
| out = dict(sd) | |
| seeded_prefixes: set[str] = set() | |
| for target_key, target in target_sd.items(): | |
| if not isinstance(target_key, str) or ".ff.experts." not in target_key: | |
| continue | |
| match = re.match(r"(blocks\.\d+\.ff\.)experts\.\d+\.(0|2)\.(weight|bias)$", target_key) | |
| if not match: | |
| continue | |
| prefix = match.group(1) | |
| legacy_key = f"{prefix}{match.group(2)}.{match.group(3)}" | |
| src = out.get(legacy_key) | |
| if target_key not in out and torch.is_tensor(src) and torch.is_tensor(target) and tuple(src.shape) == tuple(target.shape): | |
| out[target_key] = src | |
| seeded_prefixes.add(prefix) | |
| for prefix in seeded_prefixes: | |
| router_key = prefix + "router.weight" | |
| router_target = target_sd.get(router_key) | |
| if router_key not in out and torch.is_tensor(router_target): | |
| out[router_key] = router_target.detach().clone() | |
| for legacy_suffix in ("0.weight", "0.bias", "2.weight", "2.bias"): | |
| out.pop(prefix + legacy_suffix, None) | |
| return out | |
| def _reconcile_shared_expert_keys(sd: dict, target_sd: dict) -> dict: | |
| """Warm-start compat between shared-expert (4.3) and shared-less (4.2) checkpoints. | |
| - Shared-less checkpoint into a model WITH shared experts: fill the missing | |
| `.ff.shared.` keys from the freshly initialised module values. The shared | |
| output layer is zero-initialised, so the warm-started model is numerically | |
| identical to the source checkpoint at step 0 (it then learns to contribute). | |
| - Shared-expert checkpoint into a model WITHOUT them: drop the `.ff.shared.` | |
| keys (everything transferable is kept; only the shared path is shed). | |
| """ | |
| if not isinstance(sd, dict) or not isinstance(target_sd, dict): | |
| return sd | |
| out = dict(sd) | |
| filled = 0 | |
| dropped = 0 | |
| for key, target in target_sd.items(): | |
| if isinstance(key, str) and ".ff.shared." in key and key not in out and torch.is_tensor(target): | |
| out[key] = target.detach().clone() | |
| filled += 1 | |
| for key in list(out.keys()): | |
| if isinstance(key, str) and ".ff.shared." in key and key not in target_sd: | |
| out.pop(key) | |
| dropped += 1 | |
| if filled: | |
| print(f"[warm-start] shared experts: {filled} keys init fresh (zero-init no-op)", flush=True) | |
| if dropped: | |
| print(f"[warm-start] shared experts: {dropped} checkpoint keys dropped (model has none)", flush=True) | |
| return out | |
| def _prepare_core_state_dict_for_load(core: nn.Module, sd: dict) -> dict: | |
| sd = _strip_orig_mod_prefix(sd) | |
| sd = _fuse_qkv_in_state_dict(dict(sd)) if isinstance(sd, dict) else sd | |
| if isinstance(sd, dict): | |
| sd = _expand_dense_ffn_to_moe_state_dict(sd, core.state_dict()) | |
| sd = _reconcile_shared_expert_keys(sd, core.state_dict()) | |
| return sd | |
| def _split_qkv_in_state_dict_for_test(sd: dict) -> dict: | |
| out = dict(sd) | |
| for key in list(out.keys()): | |
| if not isinstance(key, str) or not key.endswith(".qkv.weight"): | |
| continue | |
| base = key[: -len(".qkv.weight")] | |
| q, k, v = out.pop(key).chunk(3, dim=0) | |
| out[base + ".q.weight"] = q.clone() | |
| out[base + ".k.weight"] = k.clone() | |
| out[base + ".v.weight"] = v.clone() | |
| return out | |
| def _clone_opt_value(value): | |
| if torch.is_tensor(value): | |
| return value.detach().clone() | |
| return copy.deepcopy(value) | |
| def _optimizer_param_name_lookup(core, ar_h, sat_h, nat_h=None) -> dict[int, str]: | |
| out = {} | |
| for prefix, module in (("core", core), ("ar", ar_h), ("sat", sat_h), ("nat", nat_h)): | |
| if module is None: | |
| continue | |
| for name, param in module.named_parameters(): | |
| out.setdefault(id(param), f"{prefix}.{name}") | |
| return out | |
| def _optimizer_group_param_names(opt, core, ar_h, sat_h, nat_h=None) -> List[List[str]]: | |
| lookup = _optimizer_param_name_lookup(core, ar_h, sat_h, nat_h) | |
| return [ | |
| [lookup.get(id(param), f"<unknown:{id(param)}>") for param in group["params"]] | |
| for group in opt.param_groups | |
| ] | |
| def _legacy_names_for_current_param(name: str) -> List[str]: | |
| if name.endswith(".qkv.weight"): | |
| base = name[: -len(".qkv.weight")] | |
| return [base + ".q.weight", base + ".k.weight", base + ".v.weight"] | |
| return [name] | |
| def _fuse_legacy_optimizer_param_state(states: List[dict]) -> Optional[dict]: | |
| if len(states) < 2 or any(not isinstance(state, dict) for state in states): | |
| return None | |
| common = set(states[0]) | |
| for state in states[1:]: | |
| common &= set(state) | |
| out = {} | |
| for key in common: | |
| vals = [state[key] for state in states] | |
| if all(torch.is_tensor(v) for v in vals): | |
| shape = vals[0].shape | |
| if vals[0].ndim > 0 and all(v.shape == shape for v in vals[1:]): | |
| out[key] = torch.cat([v.detach().clone() for v in vals], dim=0).contiguous() | |
| else: | |
| out[key] = vals[0].detach().clone() | |
| else: | |
| out[key] = copy.deepcopy(vals[0]) | |
| return out | |
| def _fuse_legacy_qkv_optimizer_state(opt_state: dict, opt, core, ar_h, sat_h, nat_h=None) -> Optional[dict]: | |
| """Remap pre-QKV-fusion AdamW state to the current fused parameter layout.""" | |
| if not isinstance(opt_state, dict) or "state" not in opt_state or "param_groups" not in opt_state: | |
| return None | |
| current_sd = opt.state_dict() | |
| current_names = _optimizer_group_param_names(opt, core, ar_h, sat_h, nat_h) | |
| legacy_names = [ | |
| [legacy for name in group_names for legacy in _legacy_names_for_current_param(name)] | |
| for group_names in current_names | |
| ] | |
| if len(legacy_names) != len(opt_state.get("param_groups", [])): | |
| return None | |
| legacy_name_to_pid = {} | |
| for group_idx, names in enumerate(legacy_names): | |
| old_params = list(opt_state["param_groups"][group_idx].get("params", [])) | |
| if len(names) != len(old_params): | |
| return None | |
| for name, pid in zip(names, old_params): | |
| legacy_name_to_pid[name] = pid | |
| new_groups = [] | |
| for group_idx, current_group in enumerate(current_sd["param_groups"]): | |
| new_group = copy.deepcopy(opt_state["param_groups"][group_idx]) | |
| new_group["params"] = list(current_group["params"]) | |
| if "param_names" in new_group: | |
| new_group["param_names"] = list(current_names[group_idx]) | |
| new_groups.append(new_group) | |
| old_states = opt_state.get("state", {}) | |
| new_states = {} | |
| for group_names, current_group in zip(current_names, current_sd["param_groups"]): | |
| for name, new_pid in zip(group_names, current_group["params"]): | |
| legacy_set = _legacy_names_for_current_param(name) | |
| if len(legacy_set) > 1: | |
| old_pids = [legacy_name_to_pid.get(legacy) for legacy in legacy_set] | |
| if all(pid in old_states for pid in old_pids): | |
| fused = _fuse_legacy_optimizer_param_state([old_states[pid] for pid in old_pids]) | |
| if fused is not None: | |
| new_states[new_pid] = fused | |
| continue | |
| old_pid = legacy_name_to_pid.get(name) | |
| if old_pid in old_states: | |
| new_states[new_pid] = {key: _clone_opt_value(value) for key, value in old_states[old_pid].items()} | |
| return {"state": new_states, "param_groups": new_groups} | |
| def save_delta(core, ar_h, sat_h, nat_h, step: int, seen_tok: int, save_dir: pathlib.Path, phase_name: str): | |
| """Save weight-only delta in background thread. Non-blocking.""" | |
| global _delta_thread | |
| # Wait for any previous delta write to finish | |
| if _delta_thread is not None and _delta_thread.is_alive(): | |
| _delta_thread.join(timeout=60) | |
| # Snapshot weights to CPU (detach from GPU graph) | |
| with _delta_lock: | |
| tensors = { | |
| "core": {k: v.detach().cpu() for k, v in _checkpoint_state_dict(core).items()}, | |
| "ar": {k: v.detach().cpu() for k, v in _checkpoint_state_dict(ar_h).items()}, | |
| "sat": {k: v.detach().cpu() for k, v in _checkpoint_state_dict(sat_h).items()}, | |
| } | |
| if nat_h is not None: | |
| tensors["nat"] = {k: v.detach().cpu() for k, v in _checkpoint_state_dict(nat_h).items()} | |
| meta = {"step": step, "seen_tok": seen_tok, "wall_time": time.time(), "delta": True, **_tokenizer_payload()} | |
| path = save_dir / f"{phase_name}_delta_step{step:08d}.pt" | |
| _delta_thread = threading.Thread(target=_do_delta_save, args=(tensors, path, meta), daemon=True) | |
| _delta_thread.start() | |
| def _prune_delta_files_to_count(save_dir: pathlib.Path, phase_name: str, keep_count: int): | |
| """Keep only the newest keep_count complete delta files.""" | |
| try: | |
| pattern = f"{phase_name}_delta_step*.pt" | |
| deltas = sorted( | |
| [p for p in save_dir.glob(pattern) if p.stat().st_size > 0], | |
| key=lambda p: p.stat().st_mtime | |
| ) | |
| excess = len(deltas) - max(0, keep_count) | |
| if excess > 0: | |
| for p in deltas[:excess]: | |
| _delete_delta_artifacts(p) | |
| print(f" [delta-prune] deleted {p.name}") | |
| except Exception as e: | |
| print(f" [delta-prune] error: {e}") | |
| def _prune_deltas(save_dir: pathlib.Path, phase_name: str, max_deltas: int): | |
| """Keep only the most recent max_deltas delta files.""" | |
| if max_deltas is None or max_deltas <= 0: | |
| return | |
| _prune_delta_files_to_count(save_dir, phase_name, max_deltas) | |
| def _pinned_basenames(save_dir: pathlib.Path) -> set: | |
| try: | |
| txt = (save_dir / ".pinned").read_text() | |
| return {ln.strip().split("/")[-1] for ln in txt.splitlines() | |
| if ln.strip() and not ln.strip().startswith("#")} | |
| except Exception: | |
| return set() | |
| def _disk_hygiene(save_dir, phase_name: str, args, reason: str = ""): | |
| """In-file disk auto-prune so the training disk never wedges (a full disk makes | |
| Python unable to even start -> watchdog crash-loop). All AGILLM-4.2 disk pruning | |
| lives here in the single file rather than an external janitor that can silently die. | |
| Conservative: removes orphan *.tmp partial writes, full checkpoints beyond | |
| --max_ckpts, deltas beyond --delta_max_keep, stale side-cycle rounds and applied | |
| async-update artifacts, and escalates under --disk_free_floor_gb. NEVER deletes the | |
| newest full checkpoint, the resume/seed deltas, files younger than 2 min, or anything | |
| listed in <save_dir>/.pinned. Best-effort: never raises into the training loop.""" | |
| import shutil, glob as _glob | |
| try: | |
| save_dir = pathlib.Path(save_dir) | |
| ws = save_dir.parent | |
| pinned = _pinned_basenames(save_dir) | |
| floor = float(getattr(args, "disk_free_floor_gb", 0.0) or 0.0) | |
| now = time.time() | |
| def free_gb(): | |
| try: | |
| return shutil.disk_usage(str(save_dir)).free / (1024 ** 3) | |
| except Exception: | |
| return 1e9 | |
| def young(p, secs=120): | |
| try: | |
| return (now - p.stat().st_mtime) < secs | |
| except Exception: | |
| return True | |
| def rm(p): | |
| try: | |
| if p.name in pinned: | |
| return False | |
| if p.is_dir(): | |
| shutil.rmtree(p, ignore_errors=True) | |
| else: | |
| p.unlink() | |
| print(f" [disk] pruned {p.name}", flush=True) | |
| return True | |
| except Exception: | |
| return False | |
| def newest_first(paths): | |
| return sorted(paths, key=lambda p: p.stat().st_mtime, reverse=True) | |
| # 1) orphan partial writes (a live save's *.tmp is younger than 2 min) | |
| for t in save_dir.glob("*.tmp"): | |
| if not young(t): | |
| rm(t) | |
| # 2) full checkpoints beyond --max_ckpts (keep newest) | |
| keep_full = max(1, int(getattr(args, "max_ckpts", 2) or 2)) | |
| fulls = newest_first(list(save_dir.glob(f"{phase_name}_step*.pt"))) | |
| for p in fulls[keep_full:]: | |
| if not young(p): | |
| rm(p) | |
| # 3) deltas beyond --delta_max_keep | |
| keep_delta = max(1, int(getattr(args, "delta_max_keep", 1) or 1)) | |
| deltas = newest_first(list(save_dir.glob(f"{phase_name}_delta_step*.pt"))) | |
| for p in deltas[keep_delta:]: | |
| if not young(p): | |
| rm(p) | |
| # 4) transient side artifacts (side-cycle rounds, applied async updates) | |
| rounds = ws / "agillm41_side_rounds" | |
| rdirs = newest_first([d for d in rounds.glob("side_cycle_*") if d.is_dir()]) if rounds.exists() else [] | |
| for p in rdirs[2:]: | |
| rm(p) | |
| su = ws / "agillm41_side_updates" | |
| inc = su / "incoming" | |
| if inc.exists(): | |
| for p in newest_first(list(inc.glob("*.pt")))[4:]: | |
| if not young(p): | |
| rm(p) | |
| for sub in ("accepted", "rejected"): | |
| d = su / sub | |
| if d.exists(): | |
| for p in d.glob("*"): | |
| if not young(p, 600): | |
| rm(p) | |
| # 5) escalate under the free-space floor (transient + extra ckpts only) | |
| if floor > 0 and free_gb() < floor: | |
| print(f" [disk] below floor {floor:.0f}GB (free {free_gb():.1f}GB){(' ' + reason) if reason else ''}; escalating", flush=True) | |
| for p in rdirs[1:]: | |
| rm(p) | |
| for p in newest_first(list(save_dir.glob(f"{phase_name}_delta_step*.pt")))[1:]: | |
| if not young(p): | |
| rm(p) | |
| for p in newest_first(list(save_dir.glob(f"{phase_name}_step*.pt")))[1:]: | |
| if not young(p): | |
| rm(p) | |
| print(f" [disk] after escalation: {free_gb():.1f}GB free", flush=True) | |
| except Exception as e: | |
| print(f"[disk-hygiene] error: {e}", flush=True) | |
| def _build_val_set(source, chat_cfg, args, block): | |
| """Capture a fixed held-out token sample (val_seed stream) as (1, block+1) CPU batches. | |
| A fixed sample re-evaluated periodically gives a comparable loss curve over training.""" | |
| n = int(getattr(args, "val_tokens", 0) or 0) | |
| if n <= 0: | |
| return [] | |
| want = max(1, n // (block + 1)) * (block + 1) | |
| val_source = str(getattr(args, "val_source", "") or "").strip() | |
| use_hot_config = not bool(val_source) | |
| val_source = val_source or source | |
| print( | |
| f"[val] building held-out set from {val_source} " | |
| f"(hot_config={'on' if use_hot_config else 'off'}, seed {getattr(args, 'val_seed', 1337)})", | |
| flush=True, | |
| ) | |
| toks = [] | |
| try: | |
| for t in token_stream( | |
| val_source, want, seed=int(getattr(args, "val_seed", 1337)), | |
| chat=chat_cfg.get("chat", False), | |
| chat_messages_key=chat_cfg.get("key", "messages"), | |
| sft_add_generation_prompt=chat_cfg.get("gen_prompt", False), | |
| dataset_field_text=chat_cfg.get("text_field", "text"), | |
| streaming=True, | |
| use_hot_config=use_hot_config, | |
| ): | |
| toks.append(int(t)) | |
| if len(toks) >= want: | |
| break | |
| except Exception as e: | |
| print(f"[val] failed to build val set ({type(e).__name__}: {e}); validation disabled", flush=True) | |
| return [] | |
| batches = [torch.tensor(toks[i:i + block + 1], dtype=torch.long).unsqueeze(0) | |
| for i in range(0, len(toks) - block, block + 1)] | |
| print(f"[val] held-out set ready: {len(batches)} batches x {block + 1} tokens (seed {getattr(args, 'val_seed', 1337)})", flush=True) | |
| return batches | |
| def _run_validation(core, ar_h, val_batches, args, step): | |
| """Full-stack AR cross-entropy on the fixed held-out batches (no_grad, eval mode).""" | |
| if not val_batches: | |
| return None | |
| was_training = core.training | |
| core.eval(); ar_h.eval() | |
| tot_ce, tot_tok = 0.0, 0 | |
| try: | |
| with torch.no_grad(): | |
| for ids_cpu in val_batches: | |
| ids = ids_cpu.to(DEV) | |
| with amp(args.amp): | |
| h = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args))) | |
| ce = fused_ce(h[:, :-1], ar_h.proj.weight, ids[:, 1:]) | |
| ntok = ids.size(1) - 1 | |
| tot_ce += float(ce.detach()) * ntok | |
| tot_tok += ntok | |
| except Exception as e: | |
| print(f"[val] eval error ({type(e).__name__}: {e}); skipping this round", flush=True) | |
| if was_training: | |
| core.train(); ar_h.train() | |
| return None | |
| if was_training: | |
| core.train(); ar_h.train() | |
| ce = tot_ce / max(1, tot_tok) | |
| ppl = math.exp(min(20.0, ce)) | |
| print(f"[val] step={step} tokens={tot_tok} ce={ce:.4f} ppl={ppl:.2f}", flush=True) | |
| return ce | |
| def _load_module_state_compatible(module: nn.Module, state: dict, label: str = "module") -> int: | |
| """Load matching tensors only; skip obsolete untied vocab matrices for tied heads.""" | |
| if not isinstance(state, dict): | |
| return 0 | |
| state = _strip_orig_mod_prefix(state) | |
| tgt_sd = module.state_dict() | |
| tied = bool(getattr(module, "tie_weights", False)) | |
| filt = {} | |
| skipped = [] | |
| for k, v in state.items(): | |
| if tied and k == "proj.weight": | |
| skipped.append(k) | |
| continue | |
| if k in tgt_sd and hasattr(v, "shape") and v.shape == tgt_sd[k].shape: | |
| filt[k] = v | |
| else: | |
| skipped.append(k) | |
| if filt: | |
| module.load_state_dict(filt, strict=False) | |
| if tied and skipped: | |
| print(f"[ckpt] {label}: tied head active; skipped old untied tensors: {', '.join(skipped[:4])}{'...' if len(skipped)>4 else ''}") | |
| return len(filt) | |
| def load_delta(path: pathlib.Path, core, ar_h, sat_h, nat_h=None): | |
| """Load weight-only delta. Returns (step, seen_tok) or raises.""" | |
| # Verify checksum if sidecar exists | |
| sha_path = path.with_suffix(".sha256") | |
| if sha_path.exists(): | |
| expected = sha_path.read_text().split()[0] | |
| actual = _sha256_file(path) | |
| if expected != actual: | |
| raise ValueError(f"Checksum mismatch for {path.name}: expected {expected[:12]}... got {actual[:12]}...") | |
| print(f" [delta] checksum OK for {path.name}") | |
| ck = torch.load(path, map_location="cpu", weights_only=False) | |
| if not ck.get("delta"): | |
| raise ValueError(f"{path.name} is not a delta checkpoint") | |
| core.load_state_dict(_prepare_core_state_dict_for_load(core, ck["weights"]["core"])) | |
| _load_module_state_compatible(ar_h, ck["weights"].get("ar", {}), "ar") | |
| _load_module_state_compatible(sat_h, ck["weights"].get("sat", {}), "sat") | |
| if nat_h is not None: | |
| nat_sd = ck["weights"].get("nat") | |
| if nat_sd is not None: | |
| _load_module_state_compatible(nat_h, nat_sd, "nat") | |
| else: | |
| print("[nat] Delta has no NAT head; keeping fresh NAT initialization") | |
| _restore_tokenizer_from_ckpt(ck, path) | |
| return ck.get("step", 0), ck.get("seen_tok", 0) | |
| def _flush_delta(): | |
| """Wait for any in-flight delta save to complete.""" | |
| global _delta_thread | |
| if _delta_thread is not None and _delta_thread.is_alive(): | |
| print(" [delta] flushing in-flight write...") | |
| _delta_thread.join(timeout=120) | |
| def save_ckpt(path: pathlib.Path, core, ar_h, sat_h, nat_h, opt, scaler, meta): | |
| path.parent.mkdir(exist_ok=True, parents=True) | |
| tmp = path.with_suffix(path.suffix + ".tmp") | |
| tokenizer_payload = _tokenizer_payload() | |
| tokenizer_payload.setdefault("tokenizer_payload_schema", 2) | |
| state = { | |
| "core": _checkpoint_state_dict(core), "ar": _checkpoint_state_dict(ar_h), "sat": _checkpoint_state_dict(sat_h), | |
| "opt": opt.state_dict(), "scaler": scaler.state_dict(), | |
| "cfg": meta.get("cfg"), | |
| **tokenizer_payload, | |
| "transformers_version": __import__("transformers").__version__, | |
| "tokenizers_version": __import__("tokenizers").__version__, | |
| "tie_weights": meta.get("tie_weights", False), | |
| **{k: v for k, v in meta.items() if k not in ("cfg", "tie_weights")} | |
| } | |
| if nat_h is not None: | |
| state["nat"] = _checkpoint_state_dict(nat_h) | |
| torch.save(state, tmp, _use_new_zipfile_serialization=False) | |
| tmp.replace(path) | |
| _write_tokenizer_sidecar(path, {k: state.get(k) for k in ("tokenizer_payload_schema", "tokenizer_id", "tokenizer_json", "tokenizer_bundle", "tokenizer_special", "transformers_version", "tokenizers_version") if state.get(k) is not None}) | |
| (path.parent / "latest.json").write_text(json.dumps({"path": str(path), "step": meta["step"]})) | |
| print(f"\n✓ saved checkpoint {path.name}") | |
| def load_ckpt(path, core, ar_h, sat_h, opt, scaler, nat_h=None): | |
| p = _resolve_ckpt(path) or path | |
| ck = _try_load(p, map_location="cpu") | |
| if ck is None: raise FileNotFoundError(f"No valid checkpoint at {p}") | |
| core.load_state_dict(_prepare_core_state_dict_for_load(core, ck["core"])) | |
| _load_module_state_compatible(ar_h, ck.get("ar", {}), "ar") | |
| _load_module_state_compatible(sat_h, ck.get("sat", {}), "sat") | |
| if nat_h is not None: | |
| if "nat" in ck: | |
| _load_module_state_compatible(nat_h, ck["nat"], "nat") | |
| else: | |
| print("[nat] Checkpoint has no NAT head; keeping fresh NAT initialization") | |
| try: | |
| opt.load_state_dict(ck["opt"]) | |
| except Exception as exc: | |
| fused_opt = _fuse_legacy_qkv_optimizer_state(ck.get("opt"), opt, core, ar_h, sat_h, nat_h) | |
| if fused_opt is not None: | |
| try: | |
| opt.load_state_dict(fused_opt) | |
| print("[ckpt] Converted legacy q/k/v optimizer state to fused qkv layout") | |
| except Exception as exc2: | |
| print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc}; qkv remap failed: {type(exc2).__name__}: {exc2})") | |
| else: | |
| print(f"[ckpt] WARNING: optimizer state incompatible; resetting optimizer ({type(exc).__name__}: {exc})") | |
| try: | |
| scaler.load_state_dict(ck["scaler"]) | |
| except Exception as exc: | |
| print(f"[ckpt] WARNING: scaler state incompatible; resetting scaler ({type(exc).__name__}: {exc})") | |
| # Restore tokenizer from checkpoint (embedded json preferred; never raises) | |
| _restore_tokenizer_from_ckpt(ck, p) | |
| # Warn if transformers version changed since checkpoint was saved | |
| if "transformers_version" in ck: | |
| import transformers as _tf | |
| if ck["transformers_version"] != _tf.__version__: | |
| print(f"[tokenizer] WARNING: checkpoint saved with transformers={ck['transformers_version']}, now running {_tf.__version__}") | |
| return ck.get("step", 0), ck.get("seen_tok", 0), ck.get("wall_time", time.time()) | |
| def _safe_load_any(path: pathlib.Path, tgt: nn.Module, key: str | None = None): | |
| p = _resolve_ckpt(path) or path | |
| if not p.exists(): return 0 | |
| ck = _try_load(p, map_location="cpu") | |
| if ck is None: return 0 | |
| sd = ck.get(key, ck) if key else ck | |
| if isinstance(sd, dict) and "state_dict" in sd: sd = sd["state_dict"] | |
| if isinstance(tgt, Encoder) or key == "core": | |
| sd = _prepare_core_state_dict_for_load(tgt, sd) | |
| else: | |
| sd = _strip_orig_mod_prefix(sd) | |
| sd = _fuse_qkv_in_state_dict(dict(sd)) if isinstance(sd, dict) else sd | |
| if not isinstance(sd, dict): | |
| return 0 | |
| tgt_sd = tgt.state_dict() | |
| filt = {k: v for k, v in sd.items() if k in tgt_sd and hasattr(v, "shape") and v.shape == tgt_sd[k].shape} | |
| if filt: tgt.load_state_dict(filt, strict=False) | |
| return len(filt) | |
| def infer_cfg_from_ckpt(path: pathlib.Path): | |
| p = _resolve_ckpt(path) or path | |
| if not p.exists(): return None | |
| sd = _try_load(p, map_location="cpu") | |
| if sd is None: return None | |
| if "cfg" in sd: return dict(sd["cfg"]) | |
| return None | |
| # ───────────────────────── Training Logic ───────────────────────── | |
| def _load_infer_head_state(module: nn.Module, state: dict, name: str): | |
| """Load inference heads across small checkpoint/schema drifts. | |
| Some older AGILLM-4 full checkpoints were saved before the current SAT/NAT | |
| head bias fields existed. For inference, preserve the old behavior by | |
| explicitly zero-filling missing bias tensors, while still failing on missing | |
| non-bias weights or shape mismatches. | |
| """ | |
| if not isinstance(state, dict): | |
| module.load_state_dict(state) | |
| return | |
| module_state = module.state_dict() | |
| patched = dict(state) | |
| zero_filled = [] | |
| shape_mismatch = [] | |
| for key, target in module_state.items(): | |
| if key not in patched and key.endswith('.bias') and torch.is_tensor(target): | |
| patched[key] = torch.zeros_like(target) | |
| zero_filled.append(key) | |
| for key, value in list(patched.items()): | |
| target = module_state.get(key) | |
| if target is None or not torch.is_tensor(value) or not torch.is_tensor(target): | |
| continue | |
| if tuple(value.shape) != tuple(target.shape): | |
| shape_mismatch.append(f"{key}: ckpt={tuple(value.shape)} model={tuple(target.shape)}") | |
| patched.pop(key) | |
| if shape_mismatch: | |
| raise RuntimeError(f"{name} checkpoint shape mismatch: " + "; ".join(shape_mismatch[:6])) | |
| loaded = module.load_state_dict(patched, strict=False) | |
| missing = [key for key in loaded.missing_keys if key not in zero_filled] | |
| if missing: | |
| raise RuntimeError(f"{name} checkpoint missing required keys: " + ", ".join(missing[:12])) | |
| notes = [] | |
| if zero_filled: | |
| notes.append("zero-filled " + ", ".join(zero_filled[:6])) | |
| if loaded.unexpected_keys: | |
| notes.append("ignored unexpected " + ", ".join(loaded.unexpected_keys[:6])) | |
| if notes: | |
| print(f"[infer-compat] {name}: " + "; ".join(notes), flush=True) | |
| def _sat_head_mlp_from_state(sd: dict) -> bool: | |
| sat_sd = sd.get("sat", {}) | |
| if sd.get("delta") and "weights" in sd: | |
| sat_sd = sd["weights"].get("sat", sat_sd) | |
| return any(str(key).startswith("proj.2.") for key in sat_sd) | |
| def _parse_grow_plan(s: str) -> List[int]: | |
| return sorted(set([int(x.strip()) for x in s.split(",") if x.strip() and int(x.strip()) >= 128])) | |
| def _count_enabled_params(*modules) -> int: | |
| seen_data_ptrs = set() | |
| total = 0 | |
| for m in modules: | |
| if m is None: | |
| continue | |
| for p in m.parameters(): | |
| if p.data_ptr() not in seen_data_ptrs: | |
| seen_data_ptrs.add(p.data_ptr()) | |
| total += p.numel() | |
| return total | |
| def _target_token_ratio(args) -> float: | |
| if getattr(args, "token_param_ratio", 0.0) and args.token_param_ratio > 0: | |
| return float(args.token_param_ratio) | |
| if str(getattr(args, "preset", "")).startswith("agillm4_"): | |
| return AGILLM4_TOKEN_PARAM_RATIO | |
| return 51.2 if args.chilla_max_double else 25.0 | |
| def _phase_freeze(core: nn.Module, *, freeze_core: bool, unfreeze_ln: bool, train_emb: bool): | |
| for p in core.parameters(): p.requires_grad = not freeze_core | |
| if freeze_core: | |
| if unfreeze_ln: | |
| for blk in core.blocks: | |
| for p in blk.ln1.parameters(): p.requires_grad = True | |
| for p in blk.ln2.parameters(): p.requires_grad = True | |
| for p in core.ln.parameters(): p.requires_grad = True | |
| if train_emb: | |
| for p in core.emb.parameters(): p.requires_grad = True | |
| def _side_update_unique_path(directory: pathlib.Path, name: str) -> pathlib.Path: | |
| directory.mkdir(parents=True, exist_ok=True) | |
| dest = directory / name | |
| if not dest.exists(): | |
| return dest | |
| stem, suffix = dest.stem, dest.suffix | |
| stamp = time.strftime("%Y%m%d-%H%M%S", time.gmtime()) | |
| for idx in range(1000): | |
| candidate = directory / f"{stem}.{stamp}.{idx}{suffix}" | |
| if not candidate.exists(): | |
| return candidate | |
| return directory / f"{stem}.{stamp}.{os.getpid()}{suffix}" | |
| def _side_update_move(path: pathlib.Path, directory: pathlib.Path) -> pathlib.Path: | |
| dest = _side_update_unique_path(directory, path.name) | |
| try: | |
| path.replace(dest) | |
| except OSError: | |
| import shutil | |
| shutil.move(str(path), str(dest)) | |
| return dest | |
| def _apply_async_side_updates(core: nn.Module, cfg: dict, args, step: int) -> list[dict]: | |
| update_dir_s = str(getattr(args, "async_update_dir", "") or "").strip() | |
| alpha = float(getattr(args, "async_update_alpha", 1.0) or 0.0) | |
| if not update_dir_s or alpha <= 0.0: | |
| return [] | |
| update_dir = pathlib.Path(update_dir_s) | |
| if not update_dir.exists(): | |
| return [] | |
| max_updates = max(1, int(getattr(args, "async_update_max_per_check", 1) or 1)) | |
| max_age = float(getattr(args, "async_update_max_age_sec", 0.0) or 0.0) | |
| accepted_dir = pathlib.Path(getattr(args, "async_update_accepted_dir", "") or (update_dir.parent / "accepted")) | |
| rejected_dir = pathlib.Path(getattr(args, "async_update_rejected_dir", "") or (update_dir.parent / "rejected")) | |
| param_map = dict(core.named_parameters()) | |
| buffer_map = dict(core.named_buffers()) | |
| now = time.time() | |
| applied: list[dict] = [] | |
| candidates = sorted( | |
| [p for p in update_dir.glob("*.pt") if p.is_file() and not p.name.endswith(".tmp")], | |
| key=lambda p: p.stat().st_mtime, | |
| ) | |
| for path in candidates[:max_updates]: | |
| reject_reason = "" | |
| try: | |
| if max_age > 0 and now - path.stat().st_mtime > max_age: | |
| reject_reason = f"stale update older than {max_age:g}s" | |
| raise ValueError(reject_reason) | |
| upd = torch.load(path, map_location="cpu", weights_only=False) | |
| kind = upd.get("kind") | |
| if kind not in {"agillm35_dblock_slice_update", "agillm4_dblock_slice_update", "agillm41_dblock_slice_update"}: | |
| raise ValueError(f"bad update kind {kind!r}") | |
| if dict(upd.get("cfg", {})) != dict(cfg): | |
| raise ValueError("cfg mismatch") | |
| block_state = upd.get("block_state") | |
| if not isinstance(block_state, dict) or not block_state: | |
| raise ValueError("missing block_state") | |
| changed = 0 | |
| with torch.no_grad(): | |
| for key, value in block_state.items(): | |
| target = param_map.get(key) | |
| if target is None: | |
| target = buffer_map.get(key) | |
| if target is None: | |
| raise KeyError(f"unknown core key {key}") | |
| if tuple(value.shape) != tuple(target.shape): | |
| raise ValueError(f"{key} shape mismatch update={tuple(value.shape)} target={tuple(target.shape)}") | |
| src = value.to(device=target.device, dtype=target.dtype, non_blocking=True) | |
| if alpha >= 1.0: | |
| target.copy_(src) | |
| else: | |
| target.lerp_(src, alpha) | |
| changed += 1 | |
| del src | |
| dest = _side_update_move(path, accepted_dir) | |
| rec = { | |
| "path": str(dest), | |
| "worker_id": upd.get("worker_id"), | |
| "block_id": upd.get("block_id"), | |
| "layers": upd.get("layers"), | |
| "tokens": int(upd.get("tokens") or 0), | |
| "tok_per_sec": float(upd.get("tok_per_sec") or 0.0), | |
| "alpha": alpha, | |
| "keys": changed, | |
| } | |
| applied.append(rec) | |
| print(json.dumps({"event": "async_side_update_applied", "step": step, **rec}), flush=True) | |
| except Exception as exc: | |
| try: | |
| dest = _side_update_move(path, rejected_dir) | |
| except Exception: | |
| dest = path | |
| print( | |
| json.dumps( | |
| { | |
| "event": "async_side_update_rejected", | |
| "step": step, | |
| "path": str(dest), | |
| "error": reject_reason or str(exc), | |
| } | |
| ), | |
| flush=True, | |
| ) | |
| return applied | |
| def _optimizer_param_groups(core, ar_h, sat_h, lr_core: float, lr_head: float, nat_h=None): | |
| # Shared/tied vocab projections must appear in only one optimizer group. | |
| # VRAM-first AGILLM-4 uses one embedding/projection tensor for AR/SAT/NAT. | |
| seen: set[int] = set() | |
| groups = [] | |
| def add(params, lr): | |
| unique = [] | |
| for p in params: | |
| if not p.requires_grad: | |
| continue | |
| key = id(p) | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| unique.append(p) | |
| if unique: | |
| groups.append({"params": unique, "lr": lr}) | |
| add(core.parameters(), lr_core) | |
| add(ar_h.parameters(), lr_head) | |
| add(sat_h.parameters(), lr_head) | |
| if nat_h is not None: | |
| add(nat_h.parameters(), lr_head) | |
| return groups | |
| class PowerStep(torch.optim.Optimizer): | |
| """Memory-efficient optimizer (arXiv:2605.10335): heavy-ball momentum + signed | |
| power transform, a SINGLE buffer (no Adam second moment). Update: | |
| m_t = gamma*m_{t-1} + g_t ; theta -= lr * (sign(m)*|m|^beta + wd*theta) | |
| beta in (0,1) gives Adam-like coordinate adaptivity; beta=1 -> SGD-momentum, | |
| beta=0 -> signSGD-momentum. Half the optimizer state of Adam. | |
| Faithful AGILLM-4.2 dblock-step benchmark (small model, real EDM objective, bf16): | |
| converged faster and to a LOWER loss than AdamW/paged_adamw8bit (EMA 6.6 vs 8.7-9.5). | |
| Note: its update scale differs from Adam, so it needs its own LR (~1e-3 vs Adam's | |
| 3e-4). The fp32 momentum buffer here lives in VRAM (~+3GB at 1B params); for the | |
| 24GB 4090 a paged or int8-quantized buffer (per the paper) is the deployment path.""" | |
| def __init__(self, params, lr=1e-3, momentum=0.9, beta=0.1, weight_decay=0.0, | |
| int8=False, paged=False): | |
| if not 0.0 <= beta <= 1.0: | |
| raise ValueError(f"beta must be in [0,1], got {beta}") | |
| if int8 and paged: | |
| raise ValueError("choose at most one of PowerStep int8 / paged") | |
| # Memory modes for the single momentum buffer (VRAM is the constraint; RAM is cheap): | |
| # default -> fp32 buffer in VRAM (fastest). | |
| # int8=True -> blockwise-int8 buffer in VRAM (paper's headline; ~1/4 VRAM). | |
| # paged=True -> fp32 buffer in pinned CPU RAM (~0 persistent VRAM; spends RAM+PCIe). | |
| self._int8 = bool(int8); self._paged = bool(paged) | |
| if self._int8: | |
| import bitsandbytes.functional as _bnbF | |
| self._bnbF = _bnbF | |
| super().__init__(params, dict(lr=lr, momentum=momentum, beta=beta, weight_decay=weight_decay)) | |
| def step(self, closure=None): | |
| loss = None | |
| if closure is not None: | |
| with torch.enable_grad(): | |
| loss = closure() | |
| EPS = 1e-12 | |
| for group in self.param_groups: | |
| lr = group["lr"]; gamma = group["momentum"]; beta = group["beta"]; wd = group["weight_decay"] | |
| if self._int8 or self._paged: | |
| # Per-tensor path (blockwise-int8 in VRAM, or fp32 buffer in CPU RAM). | |
| for p in group["params"]: | |
| if p.grad is None: | |
| continue | |
| g = p.grad | |
| st = self.state[p] | |
| if self._int8: | |
| m = (torch.zeros_like(p, dtype=torch.float32) if "mq" not in st | |
| else self._bnbF.dequantize_blockwise(st["mq"], st["mstate"])) | |
| m.mul_(gamma).add_(g.float()) | |
| u = (m * (m.abs() + EPS).pow(beta - 1.0)).to(p.dtype) | |
| st["mq"], st["mstate"] = self._bnbF.quantize_blockwise(m) | |
| else: | |
| if "m" not in st: | |
| st["m"] = torch.zeros(p.shape, dtype=torch.float32, | |
| pin_memory=torch.cuda.is_available()) | |
| m = st["m"].to(p.device, non_blocking=True) | |
| m.mul_(gamma).add_(g.float()) | |
| u = (m * (m.abs() + EPS).pow(beta - 1.0)).to(p.dtype) | |
| st["m"].copy_(m, non_blocking=True) | |
| if wd != 0: | |
| p.mul_(1.0 - lr * wd) | |
| p.add_(u, alpha=-lr) | |
| continue | |
| # Fast multi-tensor (foreach) path for the default in-VRAM fp32 buffer: | |
| # batches the elementwise update across all params -> few kernel launches, | |
| # matching fused optimizers instead of one launch set per parameter. | |
| params, grads, ms = [], [], [] | |
| for p in group["params"]: | |
| if p.grad is None: | |
| continue | |
| st = self.state[p] | |
| if "m" not in st: | |
| st["m"] = torch.zeros_like(p, memory_format=torch.preserve_format) | |
| params.append(p); grads.append(p.grad); ms.append(st["m"]) | |
| if not params: | |
| continue | |
| # m = gamma*m + g | |
| torch._foreach_mul_(ms, gamma) | |
| torch._foreach_add_(ms, grads) | |
| # u = sign(m)*|m|^beta = m * (|m|+eps)^(beta-1) (avoids a separate sign pass) | |
| absm = torch._foreach_abs(ms) | |
| torch._foreach_add_(absm, EPS) | |
| torch._foreach_pow_(absm, beta - 1.0) | |
| us = torch._foreach_mul(ms, absm) | |
| if wd != 0: | |
| torch._foreach_mul_(params, 1.0 - lr * wd) | |
| torch._foreach_add_(params, us, alpha=-lr) | |
| return loss | |
| def make_optimizer(args, core, ar_h, sat_h, lr_core: float, lr_head: float, nat_h=None): | |
| groups = _optimizer_param_groups(core, ar_h, sat_h, lr_core, lr_head, nat_h) | |
| opt_name = getattr(args, "optimizer", "adamw") | |
| if opt_name == "adamw": | |
| return torch.optim.AdamW(groups) | |
| if opt_name == "powerstep": | |
| return PowerStep(groups, | |
| momentum=float(getattr(args, "powerstep_momentum", 0.9)), | |
| beta=float(getattr(args, "powerstep_beta", 0.1)), | |
| weight_decay=float(getattr(args, "weight_decay", 0.0) or 0.0), | |
| int8=bool(getattr(args, "powerstep_int8", False)), | |
| paged=bool(getattr(args, "powerstep_paged", False))) | |
| if opt_name in {"adamw8bit", "paged_adamw8bit"}: | |
| try: | |
| import bitsandbytes as bnb | |
| except Exception as exc: | |
| raise RuntimeError( | |
| f"--optimizer {opt_name} requires bitsandbytes. Install it in the training env first." | |
| ) from exc | |
| if opt_name == "paged_adamw8bit": | |
| return bnb.optim.PagedAdamW8bit(groups) | |
| return bnb.optim.AdamW8bit(groups) | |
| raise ValueError(f"unknown optimizer: {opt_name}") | |
| def _nat_ids_for_training(ids: torch.Tensor, max_tokens: int) -> torch.Tensor: | |
| if max_tokens and max_tokens > 0 and ids.size(1) > max_tokens: | |
| return ids[:, -max_tokens:] | |
| return ids | |
| def _train_phase( | |
| args, phase_name: str, | |
| core, ar_h, sat_h, nat_h, opt, scaler, | |
| start_step, seen_tok, resume_wall_time, | |
| cfg, source, steps, block_size, batch_size, | |
| chat_cfg: dict, | |
| max_ckpts: int, | |
| target_tokens_override: Optional[int] = None, | |
| tie_weights: bool = False, | |
| streaming: bool = True | |
| ): | |
| BLOCK = block_size | |
| BATCH = batch_size | |
| if target_tokens_override is not None: | |
| target_tokens = target_tokens_override | |
| else: | |
| ratio = _target_token_ratio(args) | |
| param_count = _count_enabled_params(core, ar_h, sat_h, nat_h) | |
| target_tokens = int(ratio * param_count) | |
| print(f"[{phase_name}] token_param_ratio={ratio:g} param_count={param_count:,} target_tokens={target_tokens:,}") | |
| if steps: | |
| phase_target_tokens = steps * BLOCK * BATCH | |
| total_tokens_needed = seen_tok + phase_target_tokens | |
| else: | |
| total_tokens_needed = target_tokens | |
| if total_tokens_needed <= seen_tok: | |
| print(f"[{phase_name}] target {total_tokens_needed} already reached.") | |
| return start_step, seen_tok, resume_wall_time | |
| data_seed = int(getattr(args, "data_seed", 42)) | |
| if data_seed < 0: | |
| # Streaming restarts from the dataset head with a fixed shuffle seed, so every | |
| # restart re-trains the same early data. Derive a per-resume seed instead: | |
| # deterministic for a given checkpoint, different across restarts. | |
| data_seed = 42 + int(start_step) | |
| print(f"[data] per-restart shuffle seed {data_seed} (derived from resume step)", flush=True) | |
| val_batches = _build_val_set(source, chat_cfg, args, BLOCK) | |
| last_val_mono = time.monotonic() | |
| stream = token_stream( | |
| source, total_tokens_needed, seed=data_seed, | |
| chat=chat_cfg.get("chat", False), | |
| chat_messages_key=chat_cfg.get("key", "messages"), | |
| sft_add_generation_prompt=chat_cfg.get("gen_prompt", False), | |
| dataset_field_text=chat_cfg.get("text_field", "text"), | |
| streaming=streaming | |
| ) | |
| ce_tok = nn.CrossEntropyLoss(label_smoothing=0.1) | |
| ce_gate = nn.CrossEntropyLoss() | |
| ctc = nn.CTCLoss(blank=BLANK, zero_infinity=True) | |
| pbar = SafeProgress(total=total_tokens_needed, initial=seen_tok, unit="tok") | |
| grow_plan = _parse_grow_plan(args.grow_plan) if args.auto_grow else [] | |
| buf: list[int] = [] | |
| batch_accum: list[list[int]] = [] | |
| step = start_step | |
| steps_since_last_grow = 0 | |
| oom_retries = 0 | |
| MAX_OOM_RETRIES = 2 | |
| now_wall = time.time() | |
| last_save_mono = time.monotonic() - (now_wall - (resume_wall_time or now_wall)) | |
| last_delta_step = start_step | |
| last_heartbeat_mono = time.monotonic() | |
| _disk_hygiene(pathlib.Path(args.save_dir), phase_name, args, reason="startup") | |
| if val_batches: | |
| _run_validation(core, ar_h, val_batches, args, step) | |
| print(f"[{phase_name}] Starting. Goal: {total_tokens_needed:,} tokens. Batch={BATCH}, Block={BLOCK}") | |
| print( | |
| f"[{phase_name}] AR_ONLY={args.ar_only}, SAT_EVERY={args.sat_every}, " | |
| f"NAT_EVERY={args.nat_every}, TIE_WEIGHTS={tie_weights}, STREAMING={streaming}" | |
| ) | |
| _flush_flag = [False] | |
| def _on_flush_signal(signum, frame): | |
| _flush_flag[0] = True | |
| print(f"\n[{phase_name}] flush signal received; will checkpoint at next step") | |
| try: | |
| signal.signal(signal.SIGUSR1, _on_flush_signal) | |
| print(f"[{phase_name}] on-demand flush ready: kill -USR1 {os.getpid()} or touch {pathlib.Path(args.save_dir) / 'FLUSH_NOW'}") | |
| except (ValueError, OSError): | |
| pass | |
| _DBS = _dblock_init(core, args) if getattr(args,'dblock',False) else None | |
| if DEV.type == "cuda": | |
| try: | |
| torch.cuda.empty_cache() | |
| torch.cuda.reset_peak_memory_stats() | |
| print( | |
| f"[vram] training-start cache cleared: " | |
| f"alloc={torch.cuda.memory_allocated() / (1024**3):.2f}GB " | |
| f"reserved={torch.cuda.memory_reserved() / (1024**3):.2f}GB " | |
| f"structured_masks={use_structured_masks(args)}", | |
| flush=True, | |
| ) | |
| except Exception: | |
| pass | |
| while seen_tok < total_tokens_needed: | |
| _profile_batch = _DBS is not None and int(getattr(args, "profile_steps", 0) or 0) > 0 and int(_DBS.get("profile_n", 0)) < int(getattr(args, "profile_steps", 0) or 0) | |
| _data_t = time.perf_counter() if _profile_batch else None | |
| try: | |
| while len(buf) < BLOCK: | |
| buf.append(next(stream)) | |
| except StopIteration: | |
| break | |
| if _profile_batch: | |
| try: | |
| import dblocks_train as _db_prof | |
| _db_prof._profile_add(_DBS, "data_stream", time.perf_counter() - _data_t) | |
| except Exception: | |
| pass | |
| seq = buf[:BLOCK] | |
| buf = buf[BLOCK:] | |
| batch_accum.append(seq) | |
| if len(batch_accum) < BATCH: | |
| continue | |
| _tensor_t = time.perf_counter() if _profile_batch else None | |
| ids = torch.tensor(batch_accum, device=DEV) | |
| if _profile_batch: | |
| if DEV.type == "cuda": | |
| try: | |
| torch.cuda.synchronize() | |
| except Exception: | |
| pass | |
| try: | |
| import dblocks_train as _db_prof | |
| _db_prof._profile_add(_DBS, "tensor", time.perf_counter() - _tensor_t) | |
| except Exception: | |
| pass | |
| batch_accum = [] | |
| tgt_ar = ids.clone() | |
| try: | |
| if getattr(args, "dblock", False): | |
| loss_value = _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, _DBS) | |
| else: | |
| with amp(args.amp): | |
| h_ar = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args))) | |
| logits_ar = ar_h(h_ar)[:, :-1] | |
| loss_ar = ce_tok(logits_ar.reshape(-1, VOCAB), tgt_ar[:, 1:].reshape(-1)) | |
| loss_value = float(loss_ar.detach().item()) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| loss_ar = loss_ar + _aux.to(loss_ar.dtype) | |
| scaler.scale(loss_ar).backward() | |
| del h_ar, logits_ar, loss_ar | |
| do_sat = (not args.ar_only) and (args.sat_every <= 1 or ((step + 1) % args.sat_every == 0)) | |
| if do_sat: | |
| # Same AR+SAT objective as a summed loss, but sequential backward keeps | |
| # only one core-forward activation graph live at a time on 24GB cards. | |
| with amp(args.amp): | |
| h_sat = core(ids, sat_mask(ids.size(1), structured=use_structured_masks(args))) | |
| sat_ctx = h_sat[:, :-SAT_BLOCK] | |
| tgt_sat = ids[:, SAT_BLOCK:] | |
| if sat_ctx.size(1) == 0 or sat_ctx.size(1) != tgt_sat.size(1): | |
| sat_ctx = h_sat[:, :-1] | |
| tgt_sat = ids[:, 1:] | |
| logits_sat = sat_h.proj(sat_ctx) | |
| loss_sat = ce_tok(logits_sat.reshape(-1, VOCAB), tgt_sat.reshape(-1)) | |
| if sat_h.gate is not None: | |
| sat_gate_ctx = sat_ctx[:, ::SAT_BLOCK] | |
| gate_targets = torch.ones( | |
| sat_gate_ctx.numel() // sat_gate_ctx.size(-1), device=DEV, dtype=torch.long | |
| ) | |
| loss_sat += EMIT_LAMBDA * ce_gate( | |
| sat_h.gate(sat_gate_ctx.reshape(-1, sat_gate_ctx.size(-1))), gate_targets | |
| ) | |
| loss_value += float(loss_sat.detach().item()) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| loss_sat = loss_sat + _aux.to(loss_sat.dtype) | |
| scaler.scale(loss_sat).backward() | |
| del h_sat, logits_sat, loss_sat | |
| do_nat = ( | |
| nat_h is not None | |
| and (not args.ar_only) | |
| and args.nat_every > 0 | |
| and (args.nat_every <= 1 or ((step + 1) % args.nat_every == 0)) | |
| ) | |
| if do_nat: | |
| nat_ids = _nat_ids_for_training(ids, args.nat_max_tokens) | |
| with amp(args.amp): | |
| # Mask-predict (CMLM) objective: corrupt a fraction of positions | |
| # with BLANK and reconstruct them from surrounding context. The | |
| # old CTC objective fed the clean target as input, so the head | |
| # only learned to copy and collapsed at inference on all-BLANK | |
| # input. This conditions on real context and cannot collapse. | |
| nat_in = nat_ids.clone() | |
| ratio = min(max(float(args.nat_mask_ratio), 0.05), 0.95) | |
| mask = torch.rand(nat_in.shape, device=nat_in.device) < ratio | |
| if not bool(mask.any()): | |
| mask[..., -1] = True | |
| nat_in[mask] = BLANK | |
| h_nat = core(nat_in, None) | |
| logits_nat = nat_h(h_nat) | |
| loss_nat = F.cross_entropy(logits_nat[mask].float(), nat_ids[mask]) | |
| loss_nat = float(args.nat_loss_weight) * loss_nat | |
| loss_value += float(loss_nat.detach().item()) | |
| _aux = _collect_moe_aux(core, getattr(args,'moe_aux_coef',0.0), getattr(args,'moe_z_coef',0.0)) | |
| if torch.is_tensor(_aux): | |
| loss_nat = loss_nat + _aux.to(loss_nat.dtype) | |
| scaler.scale(loss_nat).backward() | |
| del nat_ids, nat_in, mask, h_nat, logits_nat, loss_nat | |
| scaler.unscale_(opt) | |
| nn.utils.clip_grad_norm_([p for group in opt.param_groups for p in group["params"]], 1.0) | |
| scaler.step(opt) | |
| scaler.update() | |
| opt.zero_grad(set_to_none=True) | |
| except RuntimeError as e: | |
| msg = str(e).lower() | |
| if "out of memory" in msg or "cuda error" in msg: | |
| batch_accum = [] | |
| opt.zero_grad(set_to_none=True) | |
| scaler = GradScaler(enabled=(args.amp and _needs_grad_scaler())) | |
| if DEV.type == "cuda": | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| oom_retries += 1 | |
| if oom_retries <= MAX_OOM_RETRIES: | |
| print(f"\n[{phase_name} OOM] Retry {oom_retries}/{MAX_OOM_RETRIES} at Batch={BATCH}, clearing VRAM...") | |
| time.sleep(2) | |
| continue | |
| oom_retries = 0 | |
| if BATCH > 1: | |
| print(f"\n[{phase_name} OOM] Reducing Batch: {BATCH} -> {BATCH - 1} (after {MAX_OOM_RETRIES} retries)") | |
| BATCH -= 1 | |
| time.sleep(2) | |
| else: | |
| new_block = max(128, int(BLOCK * 0.8)) | |
| new_block = max(128, (new_block // 128) * 128) | |
| if new_block >= BLOCK: | |
| new_block = max(128, BLOCK - 128) | |
| print(f"\n[{phase_name} OOM] Reducing Block: {BLOCK} -> {new_block}") | |
| BLOCK = new_block | |
| time.sleep(2) | |
| steps_since_last_grow = 0 | |
| continue | |
| raise | |
| step += 1 | |
| # Periodic tokenizer spot-check: verify training data has spaces | |
| if step % 1000 == 0: | |
| try: | |
| sample_text = tok.decode(ids[0][:50].tolist(), skip_special_tokens=True) | |
| if len(sample_text) > 20 and " " not in sample_text: | |
| print(f"\n[tokenizer] ALERT step {step}: decoded batch has NO SPACES!") | |
| print(f" Sample: {repr(sample_text[:80])}") | |
| print(" Check transformers version!") | |
| except Exception: | |
| pass | |
| oom_retries = 0 | |
| toks_processed = BLOCK * BATCH | |
| seen_tok += toks_processed | |
| pbar.set_postfix(loss=f"{loss_value:.3f}", B=BATCH, L=BLOCK) | |
| pbar.update(toks_processed) | |
| async_every = int(getattr(args, "async_update_every_steps", 0) or 0) | |
| if async_every > 0 and (step % async_every) == 0: | |
| _apply_async_side_updates(core, cfg, args, step) | |
| empty_cache_every = int(getattr(args, "empty_cache_every_steps", 0) or 0) | |
| if DEV.type == "cuda" and empty_cache_every > 0 and (step % empty_cache_every) == 0: | |
| try: | |
| torch.cuda.empty_cache() | |
| except Exception: | |
| pass | |
| heartbeat_every = int(getattr(args, "heartbeat_every_sec", 300) or 0) | |
| now_mono = time.monotonic() | |
| if heartbeat_every > 0 and now_mono - last_heartbeat_mono >= heartbeat_every: | |
| mem = "" | |
| if DEV.type == "cuda": | |
| try: | |
| mem = ( | |
| f" gpu_alloc={torch.cuda.memory_allocated() / (1024**3):.2f}GB" | |
| f" gpu_reserved={torch.cuda.memory_reserved() / (1024**3):.2f}GB" | |
| f" gpu_peak={torch.cuda.max_memory_allocated() / (1024**3):.2f}GB" | |
| ) | |
| except Exception: | |
| mem = "" | |
| try: | |
| heartbeat_payload = { | |
| "schema": "agillm.run_state.v1", | |
| "model": "AGILLM4.3", | |
| "phase": "training", | |
| "trainer_phase": phase_name, | |
| "pid": int(os.getpid()), | |
| "step": int(step), | |
| "seen_tok": int(seen_tok), | |
| "loss": float(loss_value), | |
| "batch_size": int(BATCH), | |
| "block": int(BLOCK), | |
| "dblock": bool(getattr(args, "dblock", False)), | |
| "structured_masks": bool(use_structured_masks(args)), | |
| "device": str(DEV), | |
| "save_dir": str(args.save_dir), | |
| "updated_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), | |
| } | |
| if DEV.type == "cuda": | |
| try: | |
| heartbeat_payload["gpu"] = { | |
| "allocated_gb": round(torch.cuda.memory_allocated() / (1024**3), 4), | |
| "reserved_gb": round(torch.cuda.memory_reserved() / (1024**3), 4), | |
| "peak_allocated_gb": round(torch.cuda.max_memory_allocated() / (1024**3), 4), | |
| } | |
| except Exception: | |
| pass | |
| hb_path = pathlib.Path(args.save_dir) / "run_state.json" | |
| hb_tmp = hb_path.with_suffix(".json.tmp") | |
| hb_tmp.write_text(json.dumps(heartbeat_payload, sort_keys=True) + "\n") | |
| hb_tmp.replace(hb_path) | |
| top_path = pathlib.Path(args.save_dir).parent / "agillm43_run_state.json" | |
| merged = {} | |
| if top_path.exists(): | |
| try: | |
| merged = json.loads(top_path.read_text()) | |
| except Exception: | |
| merged = {} | |
| if isinstance(merged, dict): | |
| merged.update(heartbeat_payload) | |
| merged["phase"] = "training" | |
| merged["destructive_actions_allowed"] = False | |
| top_tmp = top_path.with_suffix(".json.tmp") | |
| top_tmp.write_text(json.dumps(merged, indent=2, sort_keys=True) + "\n") | |
| top_tmp.replace(top_path) | |
| except Exception as exc: | |
| print(f"[heartbeat-json] warning: {exc}", flush=True) | |
| print( | |
| f"[heartbeat] phase={phase_name} pid={os.getpid()} step={step} " | |
| f"seen_tok={seen_tok} loss={loss_value:.3f} B={BATCH} L={BLOCK} " | |
| f"dblock={bool(getattr(args, 'dblock', False))} structured_masks={use_structured_masks(args)}{mem}", | |
| flush=True, | |
| ) | |
| last_heartbeat_mono = now_mono | |
| if val_batches and int(getattr(args, "val_every_sec", 0) or 0) > 0 and \ | |
| (time.monotonic() - last_val_mono) >= int(args.val_every_sec): | |
| _run_validation(core, ar_h, val_batches, args, step) | |
| last_val_mono = time.monotonic() | |
| _flush_sentinel = pathlib.Path(args.save_dir) / "FLUSH_NOW" | |
| if _flush_flag[0] or _flush_sentinel.exists(): | |
| _flush_flag[0] = False | |
| try: | |
| _flush_sentinel.unlink() | |
| except FileNotFoundError: | |
| pass | |
| _ck_name = f"{phase_name}_step{step:08d}.pt" | |
| _flush_delta() | |
| _disk_hygiene(pathlib.Path(args.save_dir), phase_name, args, reason="pre-flush-save") | |
| _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts) | |
| save_ckpt(pathlib.Path(args.save_dir) / _ck_name, core, ar_h, sat_h, nat_h, opt, scaler, | |
| meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) | |
| last_save_mono = time.monotonic() | |
| _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep) | |
| last_delta_step = step | |
| print(f"[{phase_name}] ON-DEMAND flush saved {_ck_name} at step {step}") | |
| if args.save_every_sec > 0: | |
| now_mono = time.monotonic() | |
| if now_mono - last_save_mono >= args.save_every_sec: | |
| ck_name = f"{phase_name}_step{step:08d}.pt" | |
| _flush_delta() # wait for any in-flight delta before full save | |
| _disk_hygiene(pathlib.Path(args.save_dir), phase_name, args, reason="pre-save") | |
| _prune_checkpoints(pathlib.Path(args.save_dir), phase_name, max_ckpts) | |
| save_ckpt(pathlib.Path(args.save_dir) / ck_name, core, ar_h, sat_h, nat_h, opt, scaler, | |
| meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) | |
| last_save_mono = now_mono | |
| # Prune old deltas after a full save (they're superseded) | |
| _prune_deltas(pathlib.Path(args.save_dir), phase_name, args.delta_max_keep) | |
| last_delta_step = step # reset delta counter after full save | |
| # ── Delta checkpoint (step-based, weight-only, async) ── | |
| if args.delta_every_steps > 0 and (step - last_delta_step) >= args.delta_every_steps: | |
| save_root = pathlib.Path(args.save_dir) | |
| # AGILLM4 production runs on small rented disks. When keep=1, prune | |
| # old deltas before the async writer creates the next multi-GB file. | |
| if args.delta_max_keep and args.delta_max_keep > 0: | |
| _flush_delta() | |
| _prune_delta_files_to_count(save_root, phase_name, args.delta_max_keep - 1) | |
| save_delta(core, ar_h, sat_h, nat_h, step, seen_tok, save_root, phase_name) | |
| last_delta_step = step | |
| if args.auto_grow: | |
| steps_since_last_grow += 1 | |
| if steps_since_last_grow >= args.grow_every_steps: | |
| steps_since_last_grow = 0 | |
| try: | |
| idx = grow_plan.index(BLOCK) | |
| if idx + 1 < len(grow_plan): | |
| BLOCK = grow_plan[idx + 1] | |
| print(f"[{phase_name} Grow] Block -> {BLOCK}") | |
| if DEV.type == "cuda": torch.cuda.empty_cache() | |
| except ValueError: | |
| grow_plan = sorted(set(grow_plan + [BLOCK])) | |
| pbar.close() | |
| _flush_delta() # ensure any in-flight delta completes before final save | |
| if phase_name != "sft": | |
| save_ckpt(pathlib.Path(args.save_dir) / f"{phase_name}_final.pt", core, ar_h, sat_h, nat_h, opt, scaler, | |
| meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) | |
| else: | |
| print("[sft] Skipping duplicate sft_final.pt; final.pt will contain the SFT result.") | |
| return step, seen_tok, time.time() | |
| # ───────────────────────── Main Orchestrator ───────────────────────── | |
| def train(args): | |
| if getattr(args, "agillm3_compat", False): | |
| args.no_nat_head = True | |
| args.nat_every = 0 | |
| args.dblock_nat_weight = 0.0 | |
| args.dblock_nat_prob = 0.0 | |
| args.reinit_nat = False | |
| args.seed_nat_from_ar = False | |
| print(f"[agillm4.1] legacy compatibility mode: tokenizer={TOKENIZER_ID}, AR+SAT checkpoint schema, NAT disabled") | |
| cfg = PRESETS[args.preset].copy() | |
| tie_weights = args.tie_weights | |
| print_expansion_info(cfg, tie_weights) | |
| if not args.fresh: | |
| if args.warmstart_from: | |
| src_probe = pathlib.Path(args.warmstart_from) | |
| elif args.resume: | |
| src_probe = pathlib.Path(args.resume) | |
| else: | |
| src_probe = pathlib.Path(args.save_dir) / "final.pt" | |
| prev_cfg = infer_cfg_from_ckpt(src_probe) | |
| else: prev_cfg = None | |
| if prev_cfg: | |
| cfg.update({k: v for k, v in prev_cfg.items() if k in cfg}) | |
| if args.x2 and prev_cfg.get("layers"): cfg["layers"] = max(cfg["layers"], prev_cfg["layers"] * 2) | |
| if args.rank: cfg["rank"] = args.rank | |
| if args.x2 and not prev_cfg: cfg["layers"] *= 2 | |
| prev_moe = prev_cfg if isinstance(prev_cfg, dict) else {} | |
| if bool(getattr(args, "tie_kv", False)): | |
| cfg["tie_kv"] = True | |
| requested_moe = bool(getattr(args, "moe_ffn", DEFAULT_MOE_FFN)) | |
| if requested_moe or bool(prev_moe.get("moe_ffn", False)): | |
| cfg["moe_ffn"] = True | |
| cfg["moe_experts"] = int(getattr(args, "moe_experts", DEFAULT_MOE_EXPERTS) if requested_moe else prev_moe.get("moe_experts", DEFAULT_MOE_EXPERTS)) | |
| cfg["moe_top_k"] = int(getattr(args, "moe_top_k", DEFAULT_MOE_TOP_K) if requested_moe else prev_moe.get("moe_top_k", DEFAULT_MOE_TOP_K)) | |
| cfg["moe_mlp_mult"] = int(getattr(args, "moe_mlp_mult", DEFAULT_MOE_MLP_MULT) if requested_moe else prev_moe.get("moe_mlp_mult", DEFAULT_MOE_MLP_MULT)) | |
| cfg["moe_shared_experts"] = int(getattr(args, "moe_shared_experts", 0) if requested_moe else prev_moe.get("moe_shared_experts", 0)) | |
| cfg["moe_shared_mlp_mult"] = int(getattr(args, "moe_shared_mlp_mult", 0) if requested_moe else prev_moe.get("moe_shared_mlp_mult", 0)) | |
| else: | |
| cfg["moe_ffn"] = False | |
| use_nat_head = not bool(getattr(args, "no_nat_head", False)) | |
| if not use_nat_head: | |
| cfg["nat_head"] = False | |
| args.nat_every = 0 | |
| args.dblock_nat_weight = 0.0 | |
| args.dblock_nat_prob = 0.0 | |
| print(f"Config: {cfg}") | |
| print( | |
| "AGILLM4.1 single-file runtime: " | |
| f"attn_backend={args.attn_backend} grad_checkpoint={args.grad_checkpoint} " | |
| f"sublinear_window={args.sublinear_window} sublinear_stride={args.sublinear_stride} " | |
| f"sublinear_max_anchors={args.sublinear_max_anchors} sublinear_chunk={args.sublinear_chunk} " | |
| f"sublinear_sinks={args.sublinear_sinks} sublinear_recent_anchors={args.sublinear_recent_anchors} " | |
| f"sublinear_pooled_landmarks={args.sublinear_pooled_landmarks} " | |
| f"moe_ffn={cfg.get('moe_ffn', False)} moe_experts={cfg.get('moe_experts', 0)} " | |
| f"moe_top_k={cfg.get('moe_top_k', 0)} moe_mlp_mult={cfg.get('moe_mlp_mult', 0)}" | |
| ) | |
| core = Encoder( | |
| cfg, | |
| tie_weights=tie_weights, | |
| attn_backend=args.attn_backend, | |
| grad_checkpoint=args.grad_checkpoint, | |
| sublinear_window=args.sublinear_window, | |
| sublinear_stride=args.sublinear_stride, | |
| sublinear_max_anchors=args.sublinear_max_anchors, | |
| sublinear_chunk=args.sublinear_chunk, | |
| sublinear_sinks=args.sublinear_sinks, | |
| sublinear_recent_anchors=args.sublinear_recent_anchors, | |
| sublinear_pooled_landmarks=args.sublinear_pooled_landmarks, | |
| anchor_memory=getattr(args, "anchor_memory", DEFAULT_ANCHOR_MEMORY), | |
| anchor_stride=getattr(args, "anchor_stride", DEFAULT_ANCHOR_STRIDE), | |
| anchor_max=getattr(args, "anchor_max", DEFAULT_ANCHOR_MAX), | |
| anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION), | |
| ).to(DEV) | |
| ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) | |
| sat_h = SATHead(cfg["d"], mode="var", tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) | |
| nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(DEV) if use_nat_head else None | |
| total_params = _count_enabled_params(core, ar_h, sat_h, nat_h) | |
| print(f"Total parameters: {total_params:,}") | |
| if tie_weights: | |
| head_names = "AR/SAT/NAT" if nat_h is not None else "AR/SAT" | |
| print(f"{Colors.WARN}[weight-tying] Embedding and {head_names} vocab projections share one tensor (VRAM-first){Colors.RESET}") | |
| if not args.fresh: | |
| src = pathlib.Path(args.warmstart_from) if args.warmstart_from else pathlib.Path(args.save_dir) / "final.pt" | |
| src = _resolve_ckpt(src) | |
| if src: | |
| loaded = _safe_load_any(src, core, key="core") | |
| _safe_load_any(src, ar_h, key="ar") | |
| _safe_load_any(src, sat_h, key="sat") | |
| nat_loaded = _safe_load_any(src, nat_h, key="nat") if nat_h is not None else 0 | |
| if nat_h is not None and not nat_loaded: | |
| print("[nat] Warm-start source has no NAT head; NAT head initialized fresh") | |
| if loaded: print(f"Warm-start loaded from {src}") | |
| _phase_freeze(core, freeze_core=args.freeze_core, unfreeze_ln=args.unfreeze_ln, train_emb=args.train_emb) | |
| opt = make_optimizer(args, core, ar_h, sat_h, args.lr_core, args.lr_head, nat_h) | |
| scaler = GradScaler(enabled=(args.amp and _needs_grad_scaler())) | |
| start_step, seen_tok, last_wall = 0, 0, None | |
| if args.resume_delta and not args.fresh: | |
| delta_step, delta_tok = load_delta(pathlib.Path(args.resume_delta), core, ar_h, sat_h, nat_h) | |
| start_step, seen_tok, last_wall = delta_step, delta_tok, None | |
| print(f"Resumed from DELTA at step {start_step} (optimizer state reset — momentum rebuilds in ~100 steps)") | |
| elif args.resume and not args.fresh: | |
| start_step, seen_tok, last_wall = load_ckpt(pathlib.Path(args.resume), core, ar_h, sat_h, opt, scaler, nat_h) | |
| print(f"Resumed from step {start_step}") | |
| if getattr(args, "seed_nat_from_ar", False) and nat_h is not None and ar_h is not None: | |
| # Seed the non-autoregressive (NAT) head from the trained AR head ("father"). | |
| # Same hidden->vocab projection shape, so NAT starts knowing the token | |
| # distribution instead of from random/blank -> faster, no collapse. | |
| with torch.no_grad(): | |
| nat_h.proj.weight.copy_(ar_h.proj.weight) | |
| if nat_h.proj.bias is not None: | |
| if getattr(ar_h.proj, "bias", None) is not None: | |
| nat_h.proj.bias.copy_(ar_h.proj.bias) | |
| else: | |
| nat_h.proj.bias.zero_() | |
| print("[nat] Seeded NAT head from the AR head ('father') for the mask-predict objective") | |
| elif getattr(args, "reinit_nat", False) and nat_h is not None: | |
| for _m in nat_h.modules(): | |
| if isinstance(_m, nn.Linear): | |
| nn.init.normal_(_m.weight, mean=0.0, std=0.02) | |
| if _m.bias is not None: | |
| nn.init.zeros_(_m.bias) | |
| print("[nat] Reinitialized NAT head weights (random) for the mask-predict objective") | |
| # torch.compile AFTER loading checkpoint (key names differ) | |
| if args.compile: | |
| print("[torch.compile] Compiling model...") | |
| core = torch.compile(core, mode="reduce-overhead") | |
| ar_h = torch.compile(ar_h, mode="reduce-overhead") | |
| sat_h = torch.compile(sat_h, mode="reduce-overhead") | |
| if nat_h is not None: | |
| nat_h = torch.compile(nat_h, mode="reduce-overhead") | |
| print("[torch.compile] Done.") | |
| step, seen_tok, last_wall = _train_phase( | |
| args, "pretrain", core, ar_h, sat_h, nat_h, opt, scaler, | |
| start_step, seen_tok, last_wall, cfg, | |
| args.source, args.steps, | |
| args.block or DEFAULT_BLOCK, | |
| args.batch_size or DEFAULT_BATCH, | |
| chat_cfg={"chat": args.chat, "key": args.chat_messages_key, "gen_prompt": args.sft_add_generation_prompt, "text_field": args.dataset_field_text}, | |
| max_ckpts=args.max_ckpts, | |
| target_tokens_override=args.target_tokens, | |
| tie_weights=tie_weights | |
| ) | |
| if (not args.after_sft_source) and (args.after_sft_steps and args.after_sft_steps > 0): | |
| args.after_sft_source = DEFAULT_AFTER_SFT_SOURCES | |
| args.after_sft_chat = True | |
| if args.after_sft_add_generation_prompt is None: args.after_sft_add_generation_prompt = True | |
| if not args.after_sft_block: args.after_sft_block = DEFAULT_AFTER_SFT_BLOCK | |
| if args.after_sft_source and args.after_sft_steps and args.after_sft_steps > 0: | |
| print("\n[Orchestrator] Starting Post-Pretraining SFT Phase...") | |
| _phase_freeze(core, | |
| freeze_core=args.after_sft_freeze_core, | |
| unfreeze_ln=args.after_sft_unfreeze_ln, | |
| train_emb=args.after_sft_train_emb) | |
| opt = make_optimizer( | |
| args, | |
| core, | |
| ar_h, | |
| sat_h, | |
| args.after_sft_lr_core or args.lr_core, | |
| args.after_sft_lr_head or args.lr_head, | |
| nat_h, | |
| ) | |
| step, seen_tok, last_wall = _train_phase( | |
| args, "sft", core, ar_h, sat_h, nat_h, opt, scaler, | |
| step, seen_tok, last_wall, cfg, | |
| args.after_sft_source, args.after_sft_steps, | |
| args.after_sft_block or DEFAULT_AFTER_SFT_BLOCK, | |
| args.batch_size or DEFAULT_BATCH, | |
| chat_cfg={ | |
| "chat": args.after_sft_chat, | |
| "key": args.after_sft_chat_messages_key, | |
| "gen_prompt": args.after_sft_add_generation_prompt if args.after_sft_add_generation_prompt is not None else args.sft_add_generation_prompt, | |
| "text_field": args.after_sft_dataset_field_text | |
| }, | |
| max_ckpts=args.max_ckpts, | |
| target_tokens_override=None, | |
| tie_weights=tie_weights, | |
| streaming=True | |
| ) | |
| save_ckpt(pathlib.Path(args.save_dir) / "final.pt", core, ar_h, sat_h, nat_h, opt, scaler, | |
| meta={"cfg": cfg, "step": step, "seen_tok": seen_tok, "wall_time": time.time(), "tie_weights": tie_weights}) | |
| print("🎉 All Training Complete") | |
| # ───────────────────────── Sampling ───────────────────────── | |
| def _apply_penalties(logits, ids, n, rep_p, pres_p, freq_p): | |
| if ids.numel() == 0: return logits | |
| hist = ids[0, -n:].long() if n > 0 else ids[0].long() | |
| uniq, counts = torch.unique(hist, return_counts=True) | |
| if pres_p or freq_p: | |
| logits[..., uniq] -= (pres_p + freq_p * counts.float()) | |
| if rep_p != 1.0: | |
| sel = logits[..., uniq] | |
| logits[..., uniq] = torch.where(sel > 0, sel / rep_p, sel * rep_p) | |
| return logits | |
| def _suppress_eos(logits, args, force=False): | |
| if (force or getattr(args, "ignore_eos", False)) and EOS is not None: | |
| logits = logits.clone() | |
| logits[..., int(EOS)] = -1e9 | |
| return logits | |
| def _sample(logits, T, top_k, top_p, min_p, greedy): | |
| if greedy: return logits.argmax(-1, keepdim=True) | |
| probs = (logits / max(T, 1e-8)).softmax(-1) | |
| if top_k: | |
| v, i = torch.topk(probs, min(top_k, probs.size(-1))) | |
| probs = torch.zeros_like(probs).scatter_(-1, i, v) | |
| if top_p < 1.0: | |
| s_probs, s_idx = torch.sort(probs, descending=True, dim=-1) | |
| probs = torch.zeros_like(probs).scatter_(-1, s_idx, s_probs * (torch.cumsum(s_probs, -1) <= top_p).float()) | |
| if min_p > 0: probs[probs < min_p] = 0 | |
| if probs.sum() == 0: return logits.argmax(-1, keepdim=True) | |
| return probs.div_(probs.sum()).multinomial(1) | |
| def _dblock_block_layers(core, dblock_blocks): | |
| L = len(core.blocks) | |
| B = max(1, int(dblock_blocks)) | |
| per = max(1, L // B) | |
| groups = [] | |
| for b in range(B): | |
| lo = b * per | |
| hi = L if b == B - 1 else (b + 1) * per | |
| groups.append(list(range(lo, hi))) | |
| return groups | |
| def _dblock_select_block(sigma, bsig): | |
| for b in range(len(bsig) - 1): | |
| if bsig[b] <= sigma <= bsig[b + 1]: | |
| return b | |
| return 0 if sigma < bsig[0] else len(bsig) - 2 | |
| def _block_stream_enabled(args) -> bool: | |
| return bool(getattr(args, "block_stream", False)) | |
| def _block_stream_compute_device(args=None): | |
| return DEV | |
| def _moe_expert_stream_enabled(args) -> bool: | |
| return bool(getattr(args, "moe_expert_stream", False)) | |
| def _dtype_from_arg(args, attr: str, flag: str): | |
| name = str(getattr(args, attr, "fp32") or "fp32").lower() | |
| if name in {"fp32", "float32", "none"}: | |
| return None | |
| if name in {"fp16", "float16", "half"}: | |
| return torch.float16 | |
| if name in {"bf16", "bfloat16"}: | |
| return torch.bfloat16 | |
| raise ValueError(f"unsupported {flag} {name!r}") | |
| def _block_stream_dtype(args): | |
| return _dtype_from_arg(args, "block_stream_dtype", "--block_stream_dtype") | |
| def _infer_dtype(args): | |
| return _dtype_from_arg(args, "infer_dtype", "--infer_dtype") | |
| def _block_stream_empty_cache(args) -> bool: | |
| return bool(getattr(args, "block_stream_empty_cache", True)) and torch.cuda.is_available() | |
| def _block_stream_kv_cache_enabled(args) -> bool: | |
| return bool(getattr(args, "block_stream_kv_cache", True)) | |
| def _block_stream_cache_pages_mode(args): | |
| explicit = getattr(args, "block_stream_cache_pages", None) | |
| if explicit is None: | |
| return "auto" | |
| return "on" if bool(explicit) else "off" | |
| def _block_stream_cache_pages_enabled(args) -> bool: | |
| effective = getattr(args, "_block_stream_cache_pages_effective", None) | |
| if effective is not None: | |
| return bool(effective) | |
| return _block_stream_cache_pages_mode(args) == "on" | |
| def _module_tensor_bytes(mod) -> int: | |
| total = 0 | |
| for t in list(mod.parameters(recurse=True)) + list(mod.buffers(recurse=True)): | |
| total += int(t.numel()) * int(t.element_size()) | |
| return total | |
| def _configure_block_stream_page_cache(args, core): | |
| mode = _block_stream_cache_pages_mode(args) | |
| if mode == "off": | |
| args._block_stream_cache_pages_effective = False | |
| args._block_stream_cache_pages_reason = "explicit-off" | |
| return | |
| if mode == "on": | |
| args._block_stream_cache_pages_effective = True | |
| args._block_stream_cache_pages_reason = "explicit-on" | |
| return | |
| if not torch.cuda.is_available() or DEV.type != "cuda": | |
| args._block_stream_cache_pages_effective = False | |
| args._block_stream_cache_pages_reason = "auto-no-cuda" | |
| return | |
| try: | |
| device_index = DEV.index if getattr(DEV, "index", None) is not None else torch.cuda.current_device() | |
| free, total = torch.cuda.mem_get_info(device_index) | |
| except (TypeError, ValueError): | |
| free, total = torch.cuda.mem_get_info() | |
| page_bytes = sum(_module_tensor_bytes(blk) for blk in core.blocks) | |
| allocated = torch.cuda.memory_allocated() | |
| reserved = torch.cuda.memory_reserved() | |
| reusable = max(0, int(reserved) - int(allocated)) | |
| usable = int(free) + int(reusable) | |
| # This is an incremental fit check, not total model size. At this point the | |
| # embedding, heads, CUDA context, and allocator slabs are already resident; | |
| # measured page-cache peak is lower than raw block parameter bytes + safety. | |
| effective_page_bytes = int(page_bytes * 0.75) | |
| safety = max(128 * 1024 * 1024, int(total * 0.005)) | |
| effective_need = effective_page_bytes + int(safety) | |
| enabled = int(usable) > int(effective_need) | |
| args._block_stream_cache_pages_effective = bool(enabled) | |
| args._block_stream_cache_pages_reason = ( | |
| f"auto usable={usable/1e9:.2f}GB free={free/1e9:.2f}GB " | |
| f"reuse={reusable/1e9:.2f}GB need={effective_need/1e9:.2f}GB raw={page_bytes/1e9:.2f}GB" | |
| ) | |
| def _block_stream_kv_store_device(args): | |
| name = str(getattr(args, "block_stream_kv_device", "cuda") or "cuda").lower() | |
| if name in {"cuda", "gpu"} and torch.cuda.is_available(): | |
| return DEV | |
| return torch.device("cpu") | |
| def _block_stream_kv_to_device(kv, device): | |
| if kv is None or isinstance(kv, KVBuffer): | |
| return kv | |
| k, v = kv | |
| if k.device == device and v.device == device: | |
| return kv | |
| return (k.to(device, non_blocking=True), v.to(device, non_blocking=True)) | |
| def _block_stream_kv_to_store(kv, device): | |
| if kv is None or isinstance(kv, KVBuffer): | |
| return kv | |
| k, v = kv | |
| if device.type == "cpu": | |
| return (k.detach().to("cpu", non_blocking=True), v.detach().to("cpu", non_blocking=True)) | |
| return (k.detach(), v.detach()) | |
| def _block_stream_layer_pages(core, args): | |
| page_layers = int(getattr(args, "block_stream_page_layers", 1) or 0) | |
| if page_layers <= 0: | |
| return _dblock_block_layers(core, int(getattr(args, "dblock_blocks", 4) or 4)) | |
| page_layers = max(1, page_layers) | |
| return [list(range(i, min(i + page_layers, len(core.blocks)))) for i in range(0, len(core.blocks), page_layers)] | |
| def _block_stream_release(mod, args): | |
| mod.to("cpu") | |
| if _block_stream_empty_cache(args): | |
| torch.cuda.empty_cache() | |
| def _block_stream_load_block(block, device, args): | |
| if _moe_expert_stream_enabled(args) and isinstance(getattr(block, "ff", None), MoEFFN): | |
| block.ln1.to(device) | |
| block.ln2.to(device) | |
| block.mha.to(device) | |
| block.ff.router.to(device) | |
| if block.ff.shared is not None: | |
| block.ff.shared.to(device) | |
| for expert in block.ff.experts: | |
| expert.to("cpu") | |
| block.ff.set_expert_stream(True, bool(getattr(args, "moe_expert_stream_empty_cache", True))) | |
| return block | |
| return block.to(device) | |
| def _block_stream_release_block(block, args): | |
| if _block_stream_cache_pages_enabled(args): | |
| return | |
| if isinstance(getattr(block, "ff", None), MoEFFN): | |
| block.ff.set_expert_stream(False, bool(getattr(args, "moe_expert_stream_empty_cache", True))) | |
| block.to("cpu") | |
| if _block_stream_empty_cache(args): | |
| torch.cuda.empty_cache() | |
| def _moe_expert_stream_stats(core): | |
| loads = 0 | |
| tokens = 0 | |
| for mod in core.modules(): | |
| if isinstance(mod, MoEFFN): | |
| st = getattr(mod, "expert_stream_stats", None) or {} | |
| loads += int(st.get("loads", 0)) | |
| tokens += int(st.get("tokens", 0)) | |
| return loads, tokens | |
| def _moe_expert_stream_reset_stats(core): | |
| for mod in core.modules(): | |
| if isinstance(mod, MoEFFN): | |
| mod.expert_stream_stats = {"loads": 0, "tokens": 0} | |
| def _block_stream_maybe_anchor(core, layer_idx, x, args): | |
| if core.anchor is None or layer_idx != core.anchor_position: | |
| return x | |
| device = _block_stream_compute_device(args) | |
| core.anchor.to(device) | |
| x, _ = core.anchor(x) | |
| _block_stream_release(core.anchor, args) | |
| return x | |
| def _block_stream_forward(core, ids, mask, args): | |
| """Run Encoder.forward while paging blocks through the compute device.""" | |
| device = _block_stream_compute_device(args) | |
| core.emb.to(device) | |
| core.ln.to(device) | |
| ids = ids.to(device) | |
| x = core.emb(ids) | |
| for page in _block_stream_layer_pages(core, args): | |
| resident = [_block_stream_load_block(core.blocks[li], device, args) for li in page] | |
| try: | |
| for li, blk in zip(page, resident): | |
| x = _run_block(blk, x, mask, False, args) | |
| x = _block_stream_maybe_anchor(core, li, x, args) | |
| finally: | |
| for blk in resident: | |
| _block_stream_release_block(blk, args) | |
| return core.ln(x) | |
| def _block_stream_forward_cached(core, ids, mask, kv_caches, total_seq_len, args): | |
| """Block-stream AR/SAT decode with KV cache. | |
| We still page layer weights through the compute device, but avoid recomputing | |
| the full prefix for every emitted token. KV tensors can stay on CUDA for speed | |
| or be stored on CPU for the lowest resident VRAM. | |
| """ | |
| device = _block_stream_compute_device(args) | |
| kv_store_device = _block_stream_kv_store_device(args) | |
| core.emb.to(device) | |
| core.ln.to(device) | |
| ids = ids.to(device) | |
| x = core.emb(ids) | |
| new_kvs = [None] * len(core.blocks) | |
| for page in _block_stream_layer_pages(core, args): | |
| resident = [_block_stream_load_block(core.blocks[li], device, args) for li in page] | |
| try: | |
| for li, blk in zip(page, resident): | |
| kv = kv_caches[li] if kv_caches else None | |
| kv = _block_stream_kv_to_device(kv, device) | |
| x, kv_out = blk(x, mask, kv, use_cache=True, total_seq_len=total_seq_len) | |
| x = _block_stream_maybe_anchor(core, li, x, args) | |
| new_kvs[li] = _block_stream_kv_to_store(kv_out, kv_store_device) | |
| finally: | |
| for blk in resident: | |
| _block_stream_release_block(blk, args) | |
| return core.ln(x), new_kvs | |
| def _edm_denoise_block(core, layers, z, sigma_t, mask, args): | |
| cs, co, ci = _edm_pre(sigma_t) | |
| h = ci * z | |
| if _block_stream_enabled(args): | |
| device = _block_stream_compute_device(args) | |
| for li in layers: | |
| blk = _block_stream_load_block(core.blocks[li], device, args) | |
| try: | |
| h = _run_block(blk, h, mask, False, args) | |
| h = _block_stream_maybe_anchor(core, li, h, args) | |
| finally: | |
| _block_stream_release_block(blk, args) | |
| else: | |
| for li in layers: | |
| h = _run_block(core.blocks[li], h, mask, False, args) | |
| return cs * z + co * h | |
| def _dblock_euler_hidden(core, ids, args): | |
| """DiffusionBlocks EDM Euler block-chain hidden state (faithful reverse ODE), | |
| adapted to agillm4.1's causal AR head. --euler_start_sigma tunes context | |
| conditioning (SDEdit-style); returns LayerNorm'd hidden [B,T,d].""" | |
| import numpy as _np | |
| dblock_blocks = int(getattr(args, "dblock_blocks", 4) or 4) | |
| steps = max(dblock_blocks, int(getattr(args, "euler_steps", 0) or (dblock_blocks * 2))) | |
| bsig = _block_sigmas(dblock_blocks) | |
| groups = _dblock_block_layers(core, dblock_blocks) | |
| sigma_min = float(bsig[0]) | |
| start = float(getattr(args, "euler_start_sigma", 0.0) or 0.0) | |
| if start <= 0.0: | |
| start = float(bsig[-1]) | |
| start = max(start, sigma_min * 2) | |
| mask = causal_mask(ids.size(1), structured=use_structured_masks(args)) | |
| e = core.emb(ids) | |
| lo, hi = math.log(sigma_min), math.log(start) | |
| sched = [float(_np.exp(hi + (lo - hi) * (i / steps))) for i in range(steps + 1)] | |
| z = e + sched[0] * torch.randn_like(e) | |
| with amp(getattr(args, "amp", False)): | |
| for i in range(steps): | |
| s_cur, s_next = sched[i], sched[i + 1] | |
| b = _dblock_select_block(s_cur, bsig) | |
| sig_t = torch.full((ids.size(0),), s_cur, device=ids.device, dtype=z.dtype) | |
| D = _edm_denoise_block(core, groups[b], z, sig_t, mask, args) | |
| z = z + ((s_next - s_cur) / s_cur) * (z - D) | |
| sig0 = torch.full((ids.size(0),), sigma_min, device=ids.device, dtype=z.dtype) | |
| D0 = _edm_denoise_block(core, groups[0], z, sig0, mask, args) | |
| return core.ln(D0) | |
| def infer(args): | |
| if args.mode == "ar": | |
| if args.temperature is None: args.temperature = 0.7 | |
| if args.top_k is None: args.top_k = 0 | |
| if args.repetition_penalty is None: args.repetition_penalty = 1.3 | |
| if args.presence_penalty is None: args.presence_penalty = 0.0 | |
| if args.frequency_penalty is None: args.frequency_penalty = 0.3 | |
| if args.penalty_last_n is None: args.penalty_last_n = 128 | |
| if args.var is None: args.var = False | |
| elif args.mode == "sat": | |
| if args.temperature is None: args.temperature = 0.5 | |
| if args.top_k is None: args.top_k = 30 | |
| if args.repetition_penalty is None: args.repetition_penalty = 2.0 | |
| if args.presence_penalty is None: args.presence_penalty = 0.6 | |
| if args.frequency_penalty is None: args.frequency_penalty = 1.0 | |
| if args.penalty_last_n is None: args.penalty_last_n = 200 | |
| if args.var is None: args.var = True | |
| else: | |
| if args.temperature is None: args.temperature = 0.8 | |
| if args.top_k is None: args.top_k = 50 | |
| if args.repetition_penalty is None: args.repetition_penalty = 1.6 | |
| if args.presence_penalty is None: args.presence_penalty = 0.6 | |
| if args.frequency_penalty is None: args.frequency_penalty = 1.0 | |
| if args.penalty_last_n is None: args.penalty_last_n = 512 | |
| if args.var is None: args.var = False | |
| min_new = int(getattr(args, "min_new", 0) or 0) | |
| if args.mode == "sat": | |
| min_new = max(min_new, SAT_BLOCK) | |
| path = _resolve_ckpt(pathlib.Path(args.ckpt)) or pathlib.Path(args.ckpt) | |
| sd = torch.load(path, map_location="cpu") | |
| # Inference never needs optimizer/scaler state. Drop it before model construction | |
| # so block-stream runs keep CPU RAM pressure lower after checkpoint load. | |
| if isinstance(sd, dict): | |
| sd.pop("opt", None) | |
| sd.pop("scaler", None) | |
| import gc as _gc | |
| _gc.collect() | |
| # Restore tokenizer from checkpoint (embedded json preferred; never raises) | |
| _restore_tokenizer_from_ckpt(sd, path) | |
| # Warn if transformers version changed since checkpoint was saved | |
| if "transformers_version" in sd: | |
| import transformers as _tf | |
| if sd["transformers_version"] != _tf.__version__: | |
| print(f"[tokenizer] WARNING: checkpoint saved with transformers={sd['transformers_version']}, now running {_tf.__version__}") | |
| # Handle delta checkpoints (weight-only, no cfg) | |
| if sd.get("delta"): | |
| print("[infer] Delta checkpoint detected, using large preset cfg") | |
| cfg = PRESETS["large"].copy() | |
| tie_weights = False | |
| # Remap: delta stores under sd["weights"]["core"/"ar"/"sat"/"nat"] | |
| sd["core"] = sd["weights"]["core"] | |
| sd["ar"] = sd["weights"]["ar"] | |
| sd["sat"] = sd["weights"]["sat"] | |
| if "nat" in sd["weights"]: | |
| sd["nat"] = sd["weights"]["nat"] | |
| else: | |
| cfg = sd["cfg"] | |
| tie_weights = sd.get("tie_weights", False) | |
| plain_output = ( | |
| bool(getattr(args, "plain_output", False)) | |
| or bool(getattr(args, "claude_friendly", False)) | |
| or not sys.stdout.isatty() | |
| ) | |
| uk_time = get_uk_time() | |
| ckpt_name = path.name | |
| if plain_output: | |
| print(f"[infer] inference_time={uk_time}") | |
| print(f"[infer] checkpoint={ckpt_name}") | |
| else: | |
| print(f"┌─────────────────────────────────────────────────┐") | |
| print(f"│ INFERENCE @ {uk_time:<35s} │") | |
| print(f"├─────────────────────────────────────────────────┤") | |
| print(f"│ Checkpoint: {ckpt_name:<35s} │") | |
| print(f"└─────────────────────────────────────────────────┘") | |
| print_expansion_info(cfg, tie_weights, plain=plain_output) | |
| block_stream = _block_stream_enabled(args) | |
| infer_dtype = None if block_stream else _infer_dtype(args) | |
| resident_dtype = (infer_dtype is not None and not block_stream) | |
| core_device = torch.device("cpu") if (block_stream or resident_dtype) else DEV | |
| core = Encoder( | |
| cfg, | |
| tie_weights=tie_weights, | |
| attn_backend=args.attn_backend, | |
| sublinear_window=args.sublinear_window, | |
| sublinear_stride=args.sublinear_stride, | |
| sublinear_max_anchors=args.sublinear_max_anchors, | |
| sublinear_chunk=args.sublinear_chunk, | |
| sublinear_sinks=args.sublinear_sinks, | |
| sublinear_recent_anchors=args.sublinear_recent_anchors, | |
| sublinear_pooled_landmarks=args.sublinear_pooled_landmarks, | |
| anchor_memory=getattr(args, "anchor_memory", DEFAULT_ANCHOR_MEMORY), | |
| anchor_stride=getattr(args, "anchor_stride", DEFAULT_ANCHOR_STRIDE), | |
| anchor_max=getattr(args, "anchor_max", DEFAULT_ANCHOR_MAX), | |
| anchor_position=getattr(args, "anchor_position", DEFAULT_ANCHOR_POSITION), | |
| ).to(core_device) | |
| head_device = torch.device("cpu") if resident_dtype else DEV | |
| ar_h = ARHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(head_device) | |
| sat_head_mlp = bool(sd.get("sat_head_mlp", False) or _sat_head_mlp_from_state(sd)) | |
| sat_h = SATHead(cfg["d"], mlp=sat_head_mlp, tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(head_device) | |
| nat_h = NATHead(cfg["d"], tie_weights=tie_weights, embedding_weight=core.emb.weight if tie_weights else None).to(head_device) if ("nat" in sd or args.mode == "nat") else None | |
| core.load_state_dict(_prepare_core_state_dict_for_load(core, sd["core"])) | |
| ar_h.load_state_dict(sd["ar"]) | |
| _load_infer_head_state(sat_h, sd["sat"], "SATHead") | |
| if nat_h is not None: | |
| if "nat" not in sd: | |
| raise ValueError("NAT inference requested, but this checkpoint has no NAT head") | |
| _load_infer_head_state(nat_h, sd["nat"], "NATHead") | |
| core.eval() | |
| ar_h.eval() | |
| sat_h.eval() | |
| if nat_h is not None: | |
| nat_h.eval() | |
| if resident_dtype: | |
| core.to(dtype=infer_dtype) | |
| ar_h.to(dtype=infer_dtype) | |
| sat_h.to(dtype=infer_dtype) | |
| if nat_h is not None: | |
| nat_h.to(dtype=infer_dtype) | |
| core.to(DEV) | |
| ar_h.to(DEV) | |
| sat_h.to(DEV) | |
| if nat_h is not None: | |
| nat_h.to(DEV) | |
| print(f"[infer] infer_dtype={str(infer_dtype).replace('torch.', '')} resident=True device={DEV}") | |
| if block_stream: | |
| stream_dtype = _block_stream_dtype(args) | |
| if stream_dtype is not None: | |
| core.to(dtype=stream_dtype) | |
| ar_h.to(dtype=stream_dtype) | |
| sat_h.to(dtype=stream_dtype) | |
| if nat_h is not None: | |
| nat_h.to(dtype=stream_dtype) | |
| print(f"[infer] block_stream_dtype={str(stream_dtype).replace('torch.', '')}") | |
| core.emb.to(DEV) | |
| core.ln.to(DEV) | |
| if core.anchor is not None: | |
| core.anchor.to("cpu") | |
| for blk in core.blocks: | |
| blk.to("cpu") | |
| if _block_stream_empty_cache(args): | |
| torch.cuda.empty_cache() | |
| _configure_block_stream_page_cache(args, core) | |
| page_desc = "dblock" if int(getattr(args, "block_stream_page_layers", 1) or 0) <= 0 else f"{int(getattr(args, 'block_stream_page_layers', 1))} layer(s)" | |
| moe_desc = " moe_expert_stream=True" if _moe_expert_stream_enabled(args) else "" | |
| page_cache_reason = getattr(args, "_block_stream_cache_pages_reason", "") | |
| page_cache_desc = f" page_cache={_block_stream_cache_pages_enabled(args)}" | |
| if page_cache_reason: | |
| page_cache_desc += f" ({page_cache_reason})" | |
| if _block_stream_kv_cache_enabled(args): | |
| kv_desc = f" KV cache=True kv_device={_block_stream_kv_store_device(args)}" | |
| else: | |
| kv_desc = " KV cache=False full-prefix recompute=True" | |
| print(f"[infer] block_stream=True device={DEV} page={page_desc}{moe_desc};{page_cache_desc}{kv_desc}") | |
| if _moe_expert_stream_enabled(args): | |
| _moe_expert_stream_reset_stats(core) | |
| total_params = _count_enabled_params(core, ar_h, sat_h, nat_h) | |
| if total_params >= 1_000_000_000: | |
| param_str = f"{total_params / 1_000_000_000:.2f}B" | |
| elif total_params >= 1_000_000: | |
| param_str = f"{total_params / 1_000_000:.2f}M" | |
| elif total_params >= 1_000: | |
| param_str = f"{total_params / 1_000:.2f}K" | |
| else: | |
| param_str = f"{total_params}" | |
| print(f"Model size: {param_str} parameters ({total_params:,})") | |
| prompt_tokens = tok.encode(args.prompt) | |
| prompt_len = len(prompt_tokens) | |
| ids = torch.tensor([prompt_tokens], device=DEV) | |
| if ids.size(1) == 0: | |
| ids = torch.tensor([[EOS]], device=DEV) | |
| prompt_len = 1 | |
| mode_str = args.mode | |
| if args.mode == "sat": | |
| mode_str = f"sat-{'var' if args.var else 'fixed'}" | |
| if plain_output: | |
| print(f"Generating ({mode_str})...") | |
| else: | |
| print(f"{Colors.INFO}Generating ({mode_str})...{Colors.RESET}") | |
| if (block_stream or resident_dtype) and torch.cuda.is_available(): | |
| torch.cuda.reset_peak_memory_stats() | |
| start = time.time() | |
| if args.mode == "ar": | |
| _euler = getattr(args, "sampler", "ar") == "euler" | |
| block_stream_kv = block_stream and _block_stream_kv_cache_enabled(args) | |
| kvs = None | |
| if not _euler and block_stream_kv: | |
| h, kvs = _block_stream_forward_cached( | |
| core, | |
| ids, | |
| causal_mask(ids.size(1), structured=use_structured_masks(args)), | |
| None, | |
| ids.size(1), | |
| args, | |
| ) | |
| elif not _euler and not block_stream: | |
| h, kvs = core(ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=ids.size(1)) | |
| for _ in range(args.max_new): | |
| if _euler: | |
| h = _dblock_euler_hidden(core, ids, args) | |
| elif block_stream and not block_stream_kv: | |
| h = _block_stream_forward(core, ids, causal_mask(ids.size(1), structured=use_structured_masks(args)), args) | |
| logits = ar_h(h)[:, -1].float() | |
| logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty) | |
| logits = _suppress_eos(logits, args) | |
| nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) | |
| ids = torch.cat([ids, nxt], 1) | |
| if EOS is not None and not getattr(args, "ignore_eos", False) and int(nxt.item()) == int(EOS): | |
| break | |
| if not _euler: | |
| if block_stream_kv: | |
| h, kvs = _block_stream_forward_cached(core, ids[:, -1:], None, kvs, ids.size(1), args) | |
| elif not block_stream: | |
| h, kvs = core(ids[:, -1:], None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1)) | |
| elif args.mode == "nat": | |
| # Iterative mask-predict decode (CMLM): keep the prompt fixed and fill the | |
| # BLANK slots, committing confident predictions each pass. Unlike the | |
| # original straight argmax path, this applies the same anti-repetition | |
| # penalties and sampler used by AR/SAT at each committed position. | |
| n_fill = max(1, int(args.max_new)) | |
| ids = torch.tensor([prompt_tokens + [BLANK] * n_fill], device=DEV) | |
| remaining = set(range(prompt_len, prompt_len + n_fill)) | |
| passes = max(1, int(args.nat_passes)) | |
| def _nat_history(current_ids: torch.Tensor): | |
| keep = current_ids[0] != BLANK | |
| if bool(keep.any()): | |
| return current_ids[:, keep] | |
| return current_ids[:, :max(1, prompt_len)] | |
| def _nat_pick(logits_pos: torch.Tensor, current_ids: torch.Tensor): | |
| logits_pos = logits_pos.clone() | |
| logits_pos[..., BLANK] = -1e9 | |
| logits_pos = _apply_penalties( | |
| logits_pos, | |
| _nat_history(current_ids), | |
| args.penalty_last_n, | |
| args.repetition_penalty, | |
| args.presence_penalty, | |
| args.frequency_penalty, | |
| ) | |
| logits_pos = _suppress_eos(logits_pos, args) | |
| return _sample(logits_pos, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) | |
| for p in range(passes): | |
| if not remaining: | |
| break | |
| h = _block_stream_forward(core, ids, None, args) if block_stream else core(ids, None) | |
| logits = nat_h(h).float() | |
| logits[..., BLANK] = -1e9 | |
| conf = logits.softmax(-1).amax(-1) | |
| k = max(1, -(-len(remaining) // (passes - p))) | |
| ordered = sorted(remaining, key=lambda q: float(conf[0, q]), reverse=True)[:k] | |
| for pos in ordered: | |
| nxt = _nat_pick(logits[:, pos, :], ids) | |
| ids[0, pos] = int(nxt.reshape(-1)[0]) | |
| remaining.discard(pos) | |
| if remaining: | |
| h = _block_stream_forward(core, ids, None, args) if block_stream else core(ids, None) | |
| logits = nat_h(h).float() | |
| logits[..., BLANK] = -1e9 | |
| for pos in sorted(remaining): | |
| nxt = _nat_pick(logits[:, pos, :], ids) | |
| ids[0, pos] = int(nxt.reshape(-1)[0]) | |
| else: | |
| cached_len = ids.size(1) | |
| block_stream_kv = block_stream and _block_stream_kv_cache_enabled(args) | |
| if block_stream_kv: | |
| h, kvs = _block_stream_forward_cached( | |
| core, | |
| ids, | |
| sat_mask(ids.size(1), structured=use_structured_masks(args)), | |
| None, | |
| cached_len, | |
| args, | |
| ) | |
| elif block_stream: | |
| h = _block_stream_forward(core, ids, sat_mask(ids.size(1), structured=use_structured_masks(args)), args) | |
| kvs = None | |
| else: | |
| h, kvs = core(ids, sat_mask(ids.size(1), structured=use_structured_masks(args)), use_cache=True, total_seq_len=cached_len) | |
| h_buffer = h[:, -SAT_BLOCK:] | |
| added = 0 | |
| stop = False | |
| # Align to a SAT block boundary with AR tokens before block emission. | |
| while ids.size(1) % SAT_BLOCK != 0 and added < args.max_new: | |
| logits = ar_h(h)[:, -1].float() | |
| logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty) | |
| logits = _suppress_eos(logits, args, added < min_new) | |
| nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) | |
| ids = torch.cat([ids, nxt], 1) | |
| added += 1 | |
| if EOS is not None and not getattr(args, "ignore_eos", False) and int(nxt.item()) == int(EOS): | |
| stop = True | |
| break | |
| if block_stream: | |
| if block_stream_kv: | |
| h, kvs = _block_stream_forward_cached(core, nxt, None, kvs, ids.size(1), args) | |
| cached_len = ids.size(1) | |
| h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:] | |
| else: | |
| h = _block_stream_forward(core, ids, sat_mask(ids.size(1), structured=use_structured_masks(args)), args) | |
| h_buffer = h[:, -SAT_BLOCK:] | |
| else: | |
| h, kvs = core(nxt, None, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1)) | |
| cached_len = ids.size(1) | |
| h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:] | |
| while added < args.max_new and not stop: | |
| logits_all, gate = sat_h(h_buffer) | |
| logits_all = logits_all.float() | |
| if gate is not None: | |
| gate = gate.float() | |
| stride = SAT_BLOCK if (not args.var or gate is None) else (gate.softmax(-1).multinomial(1).item() + 1) | |
| stride = min(int(stride), logits_all.size(1)) | |
| new_tokens = [] | |
| for i in range(int(stride)): | |
| logits = logits_all[:, i].clone() | |
| # BLANK is the SAT/NAT mask-filler token; with this tokenizer it is | |
| # ALSO the EOS id (pad==eos==1), so an unbanned SAT head "ends" on | |
| # every filler prediction while NAT (which bans BLANK) keeps going. | |
| # Ban it here exactly like the NAT path does. | |
| logits[..., BLANK] = -1e9 | |
| logits = _apply_penalties(logits, ids, args.penalty_last_n, args.repetition_penalty, args.presence_penalty, args.frequency_penalty) | |
| logits = _suppress_eos(logits, args, added < min_new) | |
| nxt = _sample(logits, args.temperature, args.top_k, args.top_p, args.min_p, args.greedy) | |
| new_tokens.append(nxt) | |
| ids = torch.cat([ids, nxt], 1) | |
| added += 1 | |
| if EOS is not None and not getattr(args, "ignore_eos", False) and int(nxt.item()) == int(EOS): | |
| stop = True | |
| break | |
| if added >= args.max_new: break | |
| if stop or added >= args.max_new: break | |
| new_ids = torch.cat(new_tokens, dim=1) | |
| if block_stream: | |
| if block_stream_kv: | |
| mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args)) | |
| h, kvs = _block_stream_forward_cached(core, new_ids, mask, kvs, ids.size(1), args) | |
| cached_len = ids.size(1) | |
| h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:] | |
| else: | |
| h = _block_stream_forward(core, ids, sat_mask(ids.size(1), structured=use_structured_masks(args)), args) | |
| h_buffer = h[:, -SAT_BLOCK:] | |
| else: | |
| mask = sat_mask_cached(new_ids.size(1), cached_len, structured=use_structured_masks(args)) | |
| h, kvs = core(new_ids, mask, kv_caches=kvs, use_cache=True, total_seq_len=ids.size(1)) | |
| cached_len = ids.size(1) | |
| h_buffer = torch.cat([h_buffer, h], dim=1)[:, -SAT_BLOCK:] | |
| elapsed = time.time() - start | |
| gen_tokens = len(ids[0]) - prompt_len | |
| tok_per_sec = gen_tokens / elapsed if elapsed > 0 else 0 | |
| if (block_stream or resident_dtype) and torch.cuda.is_available(): | |
| peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9 | |
| peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9 | |
| label = "block_stream" if block_stream else "resident" | |
| print(f"[infer] {label}_cuda_peak_alloc={peak_alloc_gb:.2f}GB peak_reserved={peak_reserved_gb:.2f}GB") | |
| if block_stream and _moe_expert_stream_enabled(args): | |
| loads, tokens = _moe_expert_stream_stats(core) | |
| print(f"[infer] moe_expert_stream_loads={loads} routed_tokens={tokens}") | |
| all_tokens = ids[0].tolist() | |
| prompt_text = tok.decode(all_tokens[:prompt_len], skip_special_tokens=True) | |
| gen_text = tok.decode(all_tokens[prompt_len:], skip_special_tokens=True) | |
| safe_prompt = _ascii_safe(prompt_text) if plain_output else prompt_text | |
| safe_gen = _ascii_safe(gen_text) if plain_output else gen_text | |
| if plain_output: | |
| print(f"{safe_prompt}{safe_gen}") | |
| print(f"[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]") | |
| else: | |
| print(f"{Colors.PROMPT}{safe_prompt}{Colors.RESET}{safe_gen}") | |
| print(f"{Colors.INFO}[{elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s]{Colors.RESET}") | |
| if getattr(args, "claude_friendly", False): | |
| claude_prompt = _ascii_safe(prompt_text) | |
| claude_gen = _ascii_safe(gen_text) | |
| print("[CLAUDE_FRIENDLY_START]") | |
| print(f"[mode={mode_str}]") | |
| print("[prompt_input]") | |
| print(claude_prompt) | |
| print("[completion]") | |
| print(claude_gen) | |
| print("[prompt_plus_completion]") | |
| print(f"{claude_prompt}{claude_gen}") | |
| print(f"[stats] {elapsed:.2f}s | {gen_tokens} tokens | {tok_per_sec:.1f} tok/s") | |
| print("[CLAUDE_FRIENDLY_END]") | |
| # ───────────────────────── CLI ───────────────────────── | |
| # ------------------------- AGILLM4.3 native supervisor ------------------------- | |
| def _agillm43_now_iso(): | |
| import time | |
| return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) | |
| def _agillm43_log_json(log_path, event, **fields): | |
| import json | |
| from pathlib import Path | |
| payload = {"event": event, "at": _agillm43_now_iso()} | |
| payload.update(fields) | |
| line = json.dumps(payload, separators=(",", ":")) | |
| print(line, flush=True) | |
| try: | |
| lp = Path(log_path) | |
| lp.parent.mkdir(parents=True, exist_ok=True) | |
| with lp.open("a", encoding="utf-8") as f: | |
| f.write(line + "\n") | |
| except Exception: | |
| pass | |
| def _agillm43_cmdline(pid): | |
| from pathlib import Path | |
| try: | |
| raw = Path(f"/proc/{int(pid)}/cmdline").read_bytes() | |
| return [x.decode("utf-8", "ignore") for x in raw.split(b"\0") if x] | |
| except Exception: | |
| return [] | |
| def _agillm43_matching_pids(kind): | |
| import os | |
| from pathlib import Path | |
| me = os.getpid() | |
| pids = [] | |
| for proc in Path("/proc").glob("[0-9]*"): | |
| try: | |
| pid = int(proc.name) | |
| except ValueError: | |
| continue | |
| if pid == me: | |
| continue | |
| cmd = _agillm43_cmdline(pid) | |
| if not cmd: | |
| continue | |
| exe = Path(cmd[0]).name.lower() | |
| if "python" not in exe: | |
| continue | |
| joined = " ".join(cmd) | |
| if "agillm41.py" not in joined: | |
| continue | |
| if kind == "train" and " train " in f" {joined} ": | |
| pids.append(pid) | |
| elif kind == "supervise" and " supervise " in f" {joined} ": | |
| pids.append(pid) | |
| return sorted(set(pids)) | |
| def _agillm43_gpu_pids(): | |
| import subprocess | |
| pids = [] | |
| try: | |
| out = subprocess.check_output( | |
| ["nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader,nounits"], | |
| text=True, | |
| stderr=subprocess.DEVNULL, | |
| timeout=10, | |
| ) | |
| for line in out.splitlines(): | |
| line = line.strip().split(",", 1)[0].strip() | |
| if line.isdigit(): | |
| pids.append(int(line)) | |
| except Exception: | |
| pass | |
| return pids | |
| def _agillm43_latest_step(save_dir): | |
| import json | |
| from pathlib import Path | |
| try: | |
| return int(json.loads((Path(save_dir) / "latest.json").read_text()).get("step", 0)) | |
| except Exception: | |
| return 0 | |
| def _agillm43_kill(pid, sig): | |
| import os | |
| try: | |
| os.kill(int(pid), sig) | |
| return True | |
| except Exception: | |
| return False | |
| def _agillm43_prepare_env(save_dir, side_dir): | |
| import os | |
| from pathlib import Path | |
| env = os.environ.copy() | |
| env.setdefault("TOKENIZERS_PARALLELISM", "false") | |
| env.setdefault("TOKENIZER_ID", "deepseek-ai/DeepSeek-V4-Pro") | |
| env.setdefault("AGILLM_ATTN_BACKEND", "sublinear") | |
| env.pop("PYTORCH_CUDA_ALLOC_CONF", None) | |
| shm = Path("/dev/shm") | |
| if shm.is_dir() and os.access(shm, os.W_OK): | |
| tmp = shm / "agillm_tmp" | |
| tmp.mkdir(parents=True, exist_ok=True) | |
| env.update({"TMPDIR": str(tmp), "TMP": str(tmp), "TEMP": str(tmp)}) | |
| hf_token_path = Path("/root/.cache/huggingface/token") | |
| if hf_token_path.exists(): | |
| token = hf_token_path.read_text(errors="ignore").strip() | |
| if token: | |
| env["HF_TOKEN"] = token | |
| env["HUGGING_FACE_HUB_TOKEN"] = token | |
| def _agillm43_load_secret_file(env_name, paths): | |
| if env.get(env_name, "").strip(): | |
| return True | |
| for raw_path in paths: | |
| try: | |
| p = Path(raw_path) | |
| if p.exists(): | |
| val = p.read_text(errors="ignore").strip() | |
| if val: | |
| env[env_name] = val | |
| return True | |
| except Exception: | |
| pass | |
| return False | |
| have_deepseek = _agillm43_load_secret_file( | |
| "DEEPSEEK_API_KEY", | |
| ( | |
| "/root/.config/agillm/deepseek_api_key", | |
| "/workspace/private/deepseek_api_key", | |
| "/workspace/agillm_private/deepseek_api_key", | |
| ), | |
| ) | |
| have_openrouter = _agillm43_load_secret_file( | |
| "OPENROUTER_API_KEY", | |
| ( | |
| "/root/.config/agillm/openrouter_api_key", | |
| "/workspace/private/openrouter_api_key", | |
| "/workspace/agillm_private/openrouter_api_key", | |
| ), | |
| ) | |
| env.setdefault("AGILLM_MAX_EXAMPLE_TOKENS", "4096") | |
| env.setdefault("AGILLM_MAX_EXAMPLE_CHARS", "32768") | |
| env.setdefault("AGILLM_DATASET_NN_ROUTER", "1") | |
| env.setdefault("AGILLM_DATASET_ROUTER_EXPLORE", "0.08") | |
| env.setdefault("AGILLM_DATASET_ROUTER_MIN_SCORE", "0.12") | |
| env.setdefault("AGILLM_DATASET_ROUTER_SHARPNESS", "2.0") | |
| env.setdefault("AGILLM_DATASET_ROUTER_TARGET_TOKENS", "2048") | |
| if have_deepseek or have_openrouter: | |
| env.setdefault("AGILLM_DATASET_AGENT_ROUTER", "0") | |
| env.setdefault("AGILLM_DATASET_AGENT_PROVIDER", "auto") | |
| Path(save_dir).mkdir(parents=True, exist_ok=True) | |
| for name in ("incoming", "accepted", "rejected"): | |
| (Path(side_dir) / name).mkdir(parents=True, exist_ok=True) | |
| return env | |
| def _agillm43_prune_save_dir(save_dir): | |
| import os | |
| from pathlib import Path | |
| d = Path(save_dir) | |
| for tmp in d.glob("*.tmp"): | |
| try: | |
| tmp.unlink() | |
| except Exception: | |
| pass | |
| ckpts = sorted(d.glob("pretrain_step*.pt"), key=lambda x: x.stat().st_mtime, reverse=True) | |
| for old in ckpts[1:]: | |
| try: | |
| old.unlink() | |
| except Exception: | |
| pass | |
| def _agillm43_latest_checkpoint_path(save_dir): | |
| import glob | |
| import json | |
| import os | |
| from pathlib import Path | |
| save = Path(save_dir) | |
| src = "" | |
| try: | |
| src = json.loads((save / "latest.json").read_text()).get("path", "") | |
| except Exception: | |
| src = "" | |
| if src and Path(src).exists(): | |
| return str(Path(src)) | |
| candidates = sorted(glob.glob(str(save / "pretrain_step*.pt")), key=os.path.getmtime) | |
| return candidates[-1] if candidates else "" | |
| def _agillm43_convert_resume_delta(save_dir, log_path): | |
| import os | |
| import re | |
| from pathlib import Path | |
| import torch | |
| save = Path(save_dir) | |
| shm = Path(os.environ.get("SHM_DIR", "/dev/shm")) | |
| if not (shm.is_dir() and os.access(shm, os.W_OK)): | |
| shm = save | |
| out = shm / "agillm43_resume.delta.pt" | |
| mark = out.parent / ".agillm43_resume.step" | |
| src = _agillm43_latest_checkpoint_path(save) | |
| if not src: | |
| seed = save / "agillm42_tiekv_seed.delta.pt" | |
| _agillm43_log_json(log_path, "native_supervisor_resume_seed", path=str(seed)) | |
| return str(seed) | |
| src_path = Path(src) | |
| m = re.search(r"step0*([0-9]+)", src_path.name) | |
| fstep = m.group(1) if m else "" | |
| try: | |
| st = src_path.stat() | |
| src_meta = { | |
| "path": str(src_path.resolve()), | |
| "name": src_path.name, | |
| "size": int(st.st_size), | |
| "mtime_ns": int(st.st_mtime_ns), | |
| "step": int(fstep) if fstep else None, | |
| } | |
| except Exception: | |
| src_meta = { | |
| "path": str(src_path), | |
| "name": src_path.name, | |
| "step": int(fstep) if fstep else None, | |
| } | |
| def _resume_delta_mark_matches(): | |
| if not (out.exists() and mark.exists()): | |
| return False | |
| try: | |
| payload = json.loads(mark.read_text().strip() or "{}") | |
| except Exception: | |
| # Old marker files only stored the step number. Rebuild once so a | |
| # stale delta from a failed probe cannot replay over a good full ckpt. | |
| return False | |
| if not isinstance(payload, dict): | |
| return False | |
| return all(payload.get(k) == v for k, v in src_meta.items()) | |
| if _resume_delta_mark_matches(): | |
| _agillm43_log_json(log_path, "native_supervisor_resume_delta_current", source=src_meta, path=str(out)) | |
| return str(out) | |
| ck = torch.load(src_path, map_location="cpu", weights_only=False) | |
| tok_keys = ("tokenizer_payload_schema", "tokenizer_id", "tokenizer_json", "tokenizer_bundle", "tokenizer_special", "transformers_version", "tokenizers_version") | |
| tok_payload = {} | |
| sidecar_payload = _read_tokenizer_sidecar(src_path) | |
| tok_payload.update({k: v for k, v in sidecar_payload.items() if k in tok_keys and v is not None}) | |
| tok_payload.update({k: ck.get(k) for k in tok_keys if isinstance(ck, dict) and ck.get(k) is not None}) | |
| if not tok_payload.get("tokenizer_json") or not tok_payload.get("tokenizer_bundle") or not tok_payload.get("tokenizer_special"): | |
| runtime_payload = _tokenizer_payload() | |
| tok_payload = {**runtime_payload, **tok_payload} | |
| tok_payload.setdefault("tokenizer_payload_schema", 2) | |
| src_meta["tokenizer_payload_schema"] = int(tok_payload.get("tokenizer_payload_schema", 2) or 2) | |
| delta = { | |
| "delta": True, | |
| "weights": {k: ck[k] for k in ("core", "ar", "sat", "nat") if k in ck}, | |
| "step": ck.get("step", 0), | |
| "seen_tok": ck.get("seen_tok", 0), | |
| "cfg": ck.get("cfg"), | |
| "source_checkpoint": src_meta, | |
| **tok_payload, | |
| } | |
| tmp = str(out) + ".tmp" | |
| torch.save(delta, tmp) | |
| os.replace(tmp, out) | |
| mark.write_text(json.dumps(src_meta, sort_keys=True)) | |
| try: | |
| Path(str(out) + ".sha256").unlink() | |
| except FileNotFoundError: | |
| pass | |
| _agillm43_log_json(log_path, "native_supervisor_resume_delta_converted", src=str(src_path), source=src_meta, path=str(out), step=int(delta.get("step", 0))) | |
| return str(out) | |
| AGILLM43_PROFILE_CHOICES = ("normal", "ar_repair", "full_ar_repair", "sat_repair", "sat_probe") | |
| def _agillm43_profile_config(profile): | |
| profile = str(profile or "normal").lower() | |
| profiles = { | |
| "normal": { | |
| "ar_prob": "0.60", "sat_prob": "0.25", "nat_prob": "0.15", | |
| "ar_loss_tokens": "512", "sat_loss_tokens": "512", "nat_loss_tokens": "512", | |
| "sat_every": "1", "nat_every": "4", | |
| }, | |
| "ar_repair": { | |
| # Hybrid-safe recovery mode. Keep AR emphasis for text quality, but | |
| # never disable SAT/NAT; AGILLM-4.3 is meant to recover as a hybrid. | |
| "ar_prob": "0.55", "sat_prob": "0.30", "nat_prob": "0.15", | |
| "ar_loss_tokens": "768", "sat_loss_tokens": "768", "nat_loss_tokens": "512", | |
| "sat_every": "1", "nat_every": "4", | |
| }, | |
| "full_ar_repair": { | |
| # Conventional full-stack AR repair for checkpoints whose dblock | |
| # generator has collapsed into high-frequency token predictions. | |
| # This intentionally avoids dblock side updates and SAT/NAT losses. | |
| "full_ar_repair": True, | |
| "batch_size": "2", "block": "768", "steps": "500", | |
| "lr_core": "1e-5", "lr_head": "5e-5", | |
| "save_every_sec": "900", | |
| }, | |
| "sat_repair": { | |
| "ar_prob": "0.45", "sat_prob": "0.40", "nat_prob": "0.15", | |
| "ar_loss_tokens": "512", "sat_loss_tokens": "1024", "nat_loss_tokens": "512", | |
| "sat_every": "1", "nat_every": "4", | |
| }, | |
| "sat_probe": { | |
| "ar_prob": "0.05", "sat_prob": "0.90", "nat_prob": "0.05", | |
| "ar_loss_tokens": "256", "sat_loss_tokens": "2048", "nat_loss_tokens": "256", | |
| "sat_every": "1", "nat_every": "4", | |
| }, | |
| } | |
| if profile not in profiles: | |
| raise ValueError(f"unknown AGILLM4.3 profile {profile!r}; choose one of {', '.join(AGILLM43_PROFILE_CHOICES)}") | |
| cfg = profiles[profile].copy() | |
| cfg["name"] = profile | |
| return cfg | |
| def _agillm43_train_argv(save_dir, side_dir, resume_delta, profile="normal", warmstart_from=None): | |
| import sys | |
| from pathlib import Path | |
| script = str(Path(__file__).resolve()) | |
| incoming = str(Path(side_dir) / "incoming") | |
| accepted = str(Path(side_dir) / "accepted") | |
| rejected = str(Path(side_dir) / "rejected") | |
| prof = _agillm43_profile_config(profile) | |
| if prof.get("full_ar_repair"): | |
| return [ | |
| sys.executable, "-u", script, "train", | |
| "--preset", "agillm4_floor", "--tie_kv", "--resume_delta", resume_delta, | |
| *(["--warmstart_from", str(warmstart_from)] if warmstart_from else []), | |
| "--moe_ffn", "--moe_experts", "2", "--moe_top_k", "1", "--moe_mlp_mult", "4", | |
| "--moe_shared_experts", "1", "--moe_shared_mlp_mult", "2", "--moe_aux_coef", "0.01", "--moe_z_coef", "0.001", | |
| "--tie_weights", "--batch_size", prof.get("batch_size", "2"), "--block", prof.get("block", "768"), | |
| "--steps", prof.get("steps", "500"), "--amp", "--grad_checkpoint", | |
| "--attn_backend", os.environ.get("AGILLM43_ATTN_BACKEND", "sdpa"), | |
| "--sublinear_window", "128", "--sublinear_stride", "128", "--sublinear_max_anchors", "128", "--sublinear_chunk", "128", | |
| "--sublinear_sinks", "4", "--sublinear_recent_anchors", "64", "--no-sublinear_pooled_landmarks", | |
| "--optimizer", "adamw8bit", "--loss_spike_skip", "3.0", "--ar_only", | |
| "--lr_core", prof.get("lr_core", "1e-5"), "--lr_head", prof.get("lr_head", "5e-5"), | |
| "--token_param_ratio", "55", "--val_tokens", "32768", "--val_every_sec", "900", | |
| "--val_source", "json:/workspace/agillm_math_numeracy_synth/train.jsonl", "--data_seed", "-1", | |
| "--save_dir", str(save_dir), "--save_every_sec", prof.get("save_every_sec", "900"), "--heartbeat_every_sec", "120", | |
| "--empty_cache_every_steps", "25", "--delta_every_steps", "0", "--delta_max_keep", "0", "--max_ckpts", "2", | |
| ] | |
| return [ | |
| sys.executable, "-u", script, "train", | |
| "--preset", "agillm4_floor", "--tie_kv", "--resume_delta", resume_delta, | |
| *(["--warmstart_from", str(warmstart_from)] if warmstart_from else []), | |
| "--dblock", "--dblock_blocks", "4", "--dblock_schedule", "loss_balanced", | |
| "--dblock_router", "transformer", "--dblock_router_blend", "0.35", "--dblock_router_ramp_steps", "256", | |
| "--dblock_warmup_steps", "16", "--dblock_sigma_curriculum_steps", "2000", | |
| "--dblock_log_every", "25", "--dblock_objective_mode", "stochastic", | |
| "--dblock_ar_prob", prof["ar_prob"], "--dblock_sat_prob", prof["sat_prob"], "--dblock_nat_prob", prof["nat_prob"], | |
| "--dblock_ar_loss_tokens", prof["ar_loss_tokens"], "--dblock_sat_loss_tokens", prof["sat_loss_tokens"], "--dblock_nat_loss_tokens", prof["nat_loss_tokens"], | |
| "--moe_ffn", "--moe_experts", "2", "--moe_top_k", "1", "--moe_mlp_mult", "4", | |
| "--moe_shared_experts", "1", "--moe_shared_mlp_mult", "2", "--moe_aux_coef", "0.01", "--moe_z_coef", "0.001", | |
| "--tie_weights", "--batch_size", os.environ.get("AGILLM43_BATCH_SIZE", "22"), "--block", os.environ.get("AGILLM43_BLOCK", "1536"), "--amp", "--attn_backend", os.environ.get("AGILLM43_ATTN_BACKEND", "sdpa"), | |
| "--sublinear_window", "128", "--sublinear_stride", "128", "--sublinear_max_anchors", "128", "--sublinear_chunk", "128", | |
| "--sublinear_sinks", "4", "--sublinear_recent_anchors", "64", "--no-sublinear_pooled_landmarks", | |
| "--dblock_checkpoint_stride", "1", "--optimizer", "adamw8bit", | |
| "--loss_spike_skip", "3.0", "--sat_every", prof["sat_every"], "--nat_every", prof["nat_every"], | |
| "--nat_max_tokens", "768", "--nat_mask_ratio", "0.5", "--token_param_ratio", "55", | |
| "--val_tokens", "32768", "--val_every_sec", "3600", "--val_source", "json:/workspace/agillm_math_numeracy_synth/train.jsonl", "--data_seed", "-1", | |
| "--save_dir", str(save_dir), "--save_every_sec", "14400", "--heartbeat_every_sec", "300", | |
| "--empty_cache_every_steps", "0", "--delta_every_steps", "25000", "--delta_max_keep", "1", "--max_ckpts", "1", | |
| "--async_update_dir", incoming, "--async_update_every_steps", "100", "--async_update_alpha", "0.05", | |
| "--async_update_max_per_check", "2", "--async_update_max_age_sec", "86400", | |
| "--async_update_accepted_dir", accepted, "--async_update_rejected_dir", rejected, | |
| ] | |
| def _agillm43_dedupe_trainers(log_path, keep_pid=None): | |
| import signal | |
| pids = _agillm43_matching_pids("train") | |
| if len(pids) <= 1: | |
| return pids | |
| gpu = [p for p in _agillm43_gpu_pids() if p in pids] | |
| keep = int(keep_pid) if keep_pid in pids else (gpu[0] if gpu else pids[0]) | |
| for pid in pids: | |
| if pid == keep: | |
| continue | |
| _agillm43_log_json(log_path, "native_supervisor_kill_duplicate", pid=pid, keep=keep) | |
| _agillm43_kill(pid, signal.SIGTERM) | |
| return [keep] | |
| def supervise_agillm43(args): | |
| import os | |
| import subprocess | |
| import time | |
| from pathlib import Path | |
| log_path = args.log | |
| save_dir = args.save_dir | |
| side_dir = args.side_dir | |
| pause_file = Path(args.pause_file) | |
| script_dir = Path(__file__).resolve().parent | |
| os.chdir(script_dir) | |
| env = _agillm43_prepare_env(save_dir, side_dir) | |
| profile = str(getattr(args, "profile", None) or os.environ.get("AGILLM43_PROFILE", "normal")) | |
| _agillm43_profile_config(profile) | |
| _agillm43_log_json(log_path, "native_supervisor_start", pid=os.getpid(), save_dir=str(save_dir), side_dir=str(side_dir), profile=profile) | |
| while True: | |
| while pause_file.exists(): | |
| _agillm43_log_json(log_path, "native_supervisor_paused", pause=str(pause_file)) | |
| time.sleep(5) | |
| if args.dedupe: | |
| _agillm43_dedupe_trainers(log_path) | |
| live = _agillm43_matching_pids("train") | |
| if live: | |
| if args.once: | |
| _agillm43_log_json(log_path, "native_supervisor_existing_trainer", pids=live) | |
| return 0 | |
| time.sleep(max(1, args.sleep_sec)) | |
| continue | |
| _agillm43_prune_save_dir(save_dir) | |
| resume_src = _agillm43_latest_checkpoint_path(save_dir) | |
| resume_delta = _agillm43_convert_resume_delta(save_dir, log_path) | |
| argv = _agillm43_train_argv(save_dir, side_dir, resume_delta, profile=profile, warmstart_from=resume_src) | |
| _agillm43_log_json(log_path, "native_supervisor_launch", profile=profile, warmstart_from=resume_src, argv=" ".join(argv)) | |
| with open(log_path, "a", encoding="utf-8", buffering=1) as lf: | |
| child = subprocess.Popen(argv, cwd=str(script_dir), env=env, stdout=lf, stderr=subprocess.STDOUT) | |
| if args.once: | |
| _agillm43_log_json(log_path, "native_supervisor_launched_once", pid=child.pid) | |
| return 0 | |
| while child.poll() is None: | |
| if args.dedupe: | |
| _agillm43_dedupe_trainers(log_path, keep_pid=child.pid) | |
| time.sleep(max(1, args.sleep_sec)) | |
| _agillm43_log_json(log_path, "native_supervisor_trainer_exit", pid=child.pid, rc=child.returncode) | |
| time.sleep(max(1, args.sleep_sec)) | |
| def hotpatch_agillm43(args): | |
| import os | |
| import signal | |
| import subprocess | |
| import time | |
| from pathlib import Path | |
| log_path = args.log | |
| save_dir = Path(args.save_dir) | |
| pause_file = Path(args.pause_file) | |
| pause_file.touch() | |
| _agillm43_log_json(log_path, "native_hotpatch_pause", pause=str(pause_file)) | |
| try: | |
| pids = _agillm43_dedupe_trainers(log_path) | |
| pids = _agillm43_matching_pids("train") | |
| if pids: | |
| gpu = [p for p in _agillm43_gpu_pids() if p in pids] | |
| keep = gpu[0] if gpu else pids[0] | |
| before = _agillm43_latest_step(save_dir) | |
| _agillm43_log_json(log_path, "native_hotpatch_flush_requested", pid=keep, before_step=before) | |
| (save_dir / "FLUSH_NOW").touch() | |
| _agillm43_kill(keep, signal.SIGUSR1) | |
| deadline = time.time() + args.wait_flush_sec | |
| while time.time() < deadline: | |
| cur = _agillm43_latest_step(save_dir) | |
| if cur > before: | |
| _agillm43_log_json(log_path, "native_hotpatch_flush_done", latest_step=cur) | |
| break | |
| time.sleep(5) | |
| else: | |
| cur = _agillm43_latest_step(save_dir) | |
| _agillm43_log_json(log_path, "native_hotpatch_flush_timeout", latest_step=cur, before_step=before) | |
| if not args.force: | |
| return 2 | |
| else: | |
| _agillm43_log_json(log_path, "native_hotpatch_no_trainer") | |
| for spid in _agillm43_matching_pids("supervise"): | |
| if spid == os.getpid(): | |
| continue | |
| _agillm43_log_json(log_path, "native_hotpatch_stop_supervisor", pid=spid) | |
| _agillm43_kill(spid, signal.SIGTERM) | |
| if args.kill_tmux: | |
| subprocess.run(["tmux", "kill-session", "-t", args.tmux_session], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | |
| time.sleep(2) | |
| for pid in _agillm43_matching_pids("train"): | |
| _agillm43_log_json(log_path, "native_hotpatch_stop_trainer", pid=pid) | |
| _agillm43_kill(pid, signal.SIGTERM) | |
| deadline = time.time() + 120 | |
| while time.time() < deadline and _agillm43_matching_pids("train"): | |
| time.sleep(2) | |
| for pid in _agillm43_matching_pids("train"): | |
| _agillm43_log_json(log_path, "native_hotpatch_kill_stubborn", pid=pid) | |
| _agillm43_kill(pid, signal.SIGKILL) | |
| pause_file.unlink(missing_ok=True) | |
| cmd = [ | |
| "python3", "-u", str(Path(__file__).resolve()), "supervise", | |
| "--save_dir", str(save_dir), "--side_dir", args.side_dir, "--log", log_path, | |
| "--pause_file", str(pause_file), "--sleep_sec", str(args.sleep_sec), | |
| "--profile", str(args.profile), | |
| ] | |
| if args.tmux: | |
| import shlex | |
| quoted = " ".join(shlex.quote(part) for part in cmd) | |
| subprocess.run(["tmux", "new-session", "-d", "-s", args.tmux_session, quoted], check=False) | |
| if not _agillm43_matching_pids("supervise"): | |
| with open(args.nohup_log, "a", encoding="utf-8") as lf: | |
| subprocess.Popen(cmd, cwd=str(Path(__file__).resolve().parent), stdout=lf, stderr=subprocess.STDOUT, start_new_session=True) | |
| _agillm43_log_json(log_path, "native_hotpatch_start_supervisor_nohup_fallback", log=args.nohup_log) | |
| else: | |
| _agillm43_log_json(log_path, "native_hotpatch_start_supervisor_tmux", session=args.tmux_session) | |
| else: | |
| with open(args.nohup_log, "a", encoding="utf-8") as lf: | |
| subprocess.Popen(cmd, cwd=str(Path(__file__).resolve().parent), stdout=lf, stderr=subprocess.STDOUT, start_new_session=True) | |
| _agillm43_log_json(log_path, "native_hotpatch_start_supervisor_nohup", log=args.nohup_log) | |
| deadline = time.time() + args.wait_start_sec | |
| while time.time() < deadline: | |
| live = _agillm43_matching_pids("train") | |
| if len(live) == 1: | |
| _agillm43_log_json(log_path, "native_hotpatch_restart_done", pid=live[0], latest_step=_agillm43_latest_step(save_dir)) | |
| return 0 | |
| if len(live) > 1: | |
| _agillm43_dedupe_trainers(log_path) | |
| time.sleep(3) | |
| _agillm43_log_json(log_path, "native_hotpatch_restart_timeout", trainer_count=len(_agillm43_matching_pids("train"))) | |
| return 3 | |
| finally: | |
| try: | |
| pause_file.unlink() | |
| except FileNotFoundError: | |
| pass | |
| def main(): | |
| ap = argparse.ArgumentParser(description="AGILLM Expansion Ratio Testing") | |
| sub = ap.add_subparsers(dest="cmd", required=True) | |
| tr = sub.add_parser("train") | |
| tr.add_argument("--preset", choices=PRESETS.keys(), default="large") | |
| tr.add_argument("--rank", type=int) | |
| tr.add_argument("--block", type=int, default=DEFAULT_BLOCK) | |
| tr.add_argument("--batch_size", type=int, default=DEFAULT_BATCH) | |
| tr.add_argument("--source", default=DEFAULT_PRETRAIN_SOURCES) | |
| tr.add_argument("--target_tokens", type=int) | |
| tr.add_argument("--token_param_ratio", type=float, default=0.0, | |
| help="If --target_tokens is omitted, train to this tokens:param ratio. AGILLM-4 presets default to 100.") | |
| tr.add_argument("--steps", type=int) | |
| tr.add_argument("--amp", action="store_true") | |
| tr.add_argument("--compile", action="store_true", help="Use torch.compile for speedup") | |
| tr.add_argument("--attn_backend", choices=["manual", "sdpa", "sublinear"], default=DEFAULT_ATTN_BACKEND, | |
| help="AGILLM-4 attention backend. sublinear uses local-window plus landmark candidates.") | |
| tr.add_argument("--grad_checkpoint", action="store_true", | |
| help="Recompute transformer blocks during backward to trade speed for longer context.") | |
| tr.add_argument("--sublinear_window", type=int, default=DEFAULT_SUBLINEAR_WINDOW, | |
| help="For --attn_backend sublinear, attend to this many local tokens on each side.") | |
| tr.add_argument("--sublinear_stride", type=int, default=DEFAULT_SUBLINEAR_STRIDE, | |
| help="For --attn_backend sublinear, use every Nth token as a landmark candidate.") | |
| tr.add_argument("--sublinear_max_anchors", type=int, default=DEFAULT_SUBLINEAR_MAX_ANCHORS, | |
| help="For --attn_backend sublinear, cap landmark candidates per query chunk.") | |
| tr.add_argument("--sublinear_chunk", type=int, default=DEFAULT_SUBLINEAR_CHUNK, | |
| help="For --attn_backend sublinear, query chunk size controlling peak gather memory.") | |
| tr.add_argument("--sublinear_sinks", type=int, default=DEFAULT_SUBLINEAR_SINKS, | |
| help="For sublinear attention, always include this many first-token attention sinks.") | |
| tr.add_argument("--sublinear_recent_anchors", type=int, default=DEFAULT_SUBLINEAR_RECENT_ANCHORS, | |
| help="For capped sublinear anchors, reserve this many anchors for the recent tail; -1 uses half.") | |
| tr.add_argument("--sublinear_pooled_landmarks", action=argparse.BooleanOptionalAction, | |
| default=DEFAULT_SUBLINEAR_POOLED_LANDMARKS, | |
| help="Use stride-segment pooled K/V summaries for sublinear landmark anchors.") | |
| tr.add_argument("--no_structured_masks", action="store_true", | |
| help="Disable structured causal/SAT masks for sublinear attention and fall back to dense masks.") | |
| tr.add_argument("--anchor_memory", action="store_true", | |
| help="Enable anchor-memory long-context augmentation (one AnchorMemoryLayer at mid-stack).") | |
| tr.add_argument("--anchor_stride", type=int, default=DEFAULT_ANCHOR_STRIDE, | |
| help="Token span compressed into one anchor (default 256).") | |
| tr.add_argument("--anchor_max", type=int, default=DEFAULT_ANCHOR_MAX, | |
| help="Max anchors retained in the rolling memory bank.") | |
| tr.add_argument("--anchor_position", type=int, default=DEFAULT_ANCHOR_POSITION, | |
| help="Block index after which to insert anchor memory (-1 = stack middle).") | |
| tr.add_argument("--kv_buffer", action="store_true", | |
| help="Use preallocated KV buffer instead of torch.cat-based cache growth.") | |
| tr.add_argument("--optimizer", choices=["adamw", "adamw8bit", "paged_adamw8bit", "powerstep"], default="adamw", | |
| help="Optimizer backend. 8-bit options reduce VRAM on 24GB production runs. 'powerstep' (arXiv:2605.10335) uses a single momentum buffer; in a faithful dblock-step benchmark it converged below Adam, but needs its own LR (~1e-3) and an int8/paged buffer to fit at B=6.") | |
| tr.add_argument("--powerstep_beta", type=float, default=0.1, | |
| help="PowerStep signed-power exponent beta in (0,1); 0.1 is the paper's recommended value.") | |
| tr.add_argument("--powerstep_momentum", type=float, default=0.9, | |
| help="PowerStep heavy-ball momentum coefficient gamma.") | |
| tr.add_argument("--powerstep_int8", action="store_true", | |
| help="PowerStep: store the momentum buffer as blockwise int8 in VRAM (~1/4 VRAM; needs bitsandbytes).") | |
| tr.add_argument("--powerstep_paged", action="store_true", | |
| help="PowerStep: keep the momentum buffer in pinned CPU RAM (~0 persistent VRAM, spends RAM+PCIe).") | |
| tr.add_argument("--save_every_sec", type=int, default=DEFAULT_SAVE_SEC) | |
| tr.add_argument("--disk_free_floor_gb", type=float, default=12.0, | |
| help="In-file disk auto-prune: when free space drops below this, escalate pruning of transient artifacts and old checkpoints. 0 disables the floor (routine keep-count pruning still runs).") | |
| tr.add_argument("--val_tokens", type=int, default=0, | |
| help="Held-out validation set size in tokens (sampled once from --val_seed stream at startup). 0 disables validation.") | |
| tr.add_argument("--val_every_sec", type=int, default=3600, | |
| help="Run held-out validation every N seconds (requires --val_tokens > 0).") | |
| tr.add_argument("--val_seed", type=int, default=1337, | |
| help="Shuffle seed for the held-out validation stream (distinct from the training data seed).") | |
| tr.add_argument("--val_source", default="", | |
| help="Optional validation-only dataset source. When set, bypasses hot_config so health probes are comparable across restarts.") | |
| tr.add_argument("--data_seed", type=int, default=42, | |
| help="Training stream shuffle seed. -1 derives a per-restart seed from the resume step so restarts do not re-train identical early data.") | |
| tr.add_argument("--heartbeat_every_sec", type=int, default=300, | |
| help="Print lightweight trainer heartbeat/status lines every N seconds; 0 disables.") | |
| tr.add_argument("--empty_cache_every_steps", type=int, default=0, | |
| help="Call torch.cuda.empty_cache() every N train steps; useful for VRAM-first runs where lower reserved VRAM matters more than speed.") | |
| tr.add_argument("--profile_steps", type=int, default=0, | |
| help="Profile the first N DBlock training steps with in-process CUDA timers; 0 disables.") | |
| tr.add_argument("--profile_log_every", type=int, default=25, | |
| help="Print averaged profiler timings every N profiled steps.") | |
| tr.add_argument("--delta_every_steps", type=int, default=DEFAULT_DELTA_STEPS, help="Weight-only delta save every N steps (0=off)") | |
| tr.add_argument("--delta_max_keep", type=int, default=DEFAULT_MAX_DELTAS, help="Max delta checkpoints to keep") | |
| tr.add_argument("--resume_delta", type=str, help="Resume from a delta (weight-only, no optimizer state)") | |
| tr.add_argument("--async_update_dir", default="", | |
| help="Optional incoming directory for verified DBlock side updates. Empty disables async side updates.") | |
| tr.add_argument("--async_update_every_steps", type=int, default=0, | |
| help="Poll --async_update_dir every N master steps. Side workers never block master progress.") | |
| tr.add_argument("--async_update_alpha", type=float, default=1.0, | |
| help="Blend factor for accepted side updates: 1.0 copies side block weights; lower values lerp into live weights.") | |
| tr.add_argument("--async_update_max_per_check", type=int, default=1, | |
| help="Maximum side-update files to apply per poll.") | |
| tr.add_argument("--async_update_max_age_sec", type=float, default=0.0, | |
| help="Reject incoming side updates older than this many seconds. 0 disables age rejection.") | |
| tr.add_argument("--async_update_accepted_dir", default="", | |
| help="Directory for applied side-update files. Defaults to a sibling accepted/ directory.") | |
| tr.add_argument("--async_update_rejected_dir", default="", | |
| help="Directory for rejected side-update files. Defaults to a sibling rejected/ directory.") | |
| tr.add_argument("--save_dir", default=str(CKDIR)) | |
| tr.add_argument("--resume", type=str) | |
| tr.add_argument("--x2", action="store_true") | |
| tr.add_argument("--warmstart_from", type=str) | |
| tr.add_argument("--fresh", action="store_true") | |
| tr.add_argument("--max_ckpts", type=int, default=None) | |
| tr.add_argument("--chilla_max_double", action="store_true") | |
| tr.add_argument("--tie_weights", action="store_true") | |
| tr.add_argument("--ar_only", action="store_true") | |
| tr.add_argument("--agillm3_compat", action="store_true", | |
| help="Legacy AGILLM3/3.5 checkpoint mode. Use TOKENIZER_ID=deepseek-ai/DeepSeek-V3.2 or the agillm35.py shim for the old tokenizer contract.") | |
| tr.add_argument("--no_nat_head", action="store_true", | |
| help="Do not instantiate/save a NAT head. Keeps AGILLM3 AR+SAT checkpoint schema and reduces params/RAM.") | |
| tr.add_argument("--sat_every", type=int, default=1, | |
| help="Train SAT every N steps. Default 1 keeps AR+SAT every step.") | |
| tr.add_argument("--nat_every", type=int, default=1, | |
| help="Train NAT every N steps with a CTC objective. Default 1 keeps AR+SAT+NAT every step.") | |
| tr.add_argument("--nat_loss_weight", type=float, default=1.0) | |
| tr.add_argument("--nat_expand", type=int, default=2, | |
| help="Repeat tokens this many times for the NAT CTC input length.") | |
| tr.add_argument("--nat_max_tokens", type=int, default=0, | |
| help="Optional cap for NAT target tokens per batch; 0 uses the whole block.") | |
| tr.add_argument("--dblock_nat_embed_noise_mode", choices=["off", "visible", "mask_plus_noise"], default="mask_plus_noise", | |
| help="NAT embedding noise mode. off=standard BLANK masking. visible=add noise to clean embeddings. mask_plus_noise=BLANK mask + noise on masked positions.") | |
| tr.add_argument("--dblock_nat_embed_noise_scale", type=float, default=1.0, | |
| help="Scale factor for embedding noise in NAT hybrid modes.") | |
| tr.add_argument("--nat_mask_ratio", type=float, default=0.5, | |
| help="Fraction of positions masked to BLANK for the NAT mask-predict (CMLM) objective.") | |
| tr.add_argument("--tie_kv", action=argparse.BooleanOptionalAction, default=False, | |
| help="Q-K=V: tie Key & Value into one projection (~50%% KV cache, -33%% qkv params). Trained-in only; not loadable into a 3-proj checkpoint.") | |
| tr.add_argument("--moe_ffn", action=argparse.BooleanOptionalAction, default=DEFAULT_MOE_FFN, | |
| help="Use Mixture-of-Experts feed-forward layers inside the transformer blocks.") | |
| tr.add_argument("--moe_experts", type=int, default=DEFAULT_MOE_EXPERTS, | |
| help="Number of FFN experts per transformer block when --moe_ffn is enabled.") | |
| tr.add_argument("--moe_top_k", type=int, default=DEFAULT_MOE_TOP_K, | |
| help="Router top-k experts per token when --moe_ffn is enabled.") | |
| tr.add_argument("--moe_mlp_mult", type=int, default=DEFAULT_MOE_MLP_MULT, | |
| help="Expert hidden-size multiplier; 4 preserves dense FFN checkpoint shape for seeding.") | |
| tr.add_argument("--moe_shared_experts", type=int, default=0, | |
| help="Always-on shared experts added to the routed output (DeepSeek/ST-MoE style). 0 disables. Output is zero-init so it merges into an existing checkpoint as a no-op then learns to contribute.") | |
| tr.add_argument("--moe_shared_mlp_mult", type=int, default=0, | |
| help="Hidden-size multiplier for shared experts (0 = same as --moe_mlp_mult). Use a smaller value (1-2) to limit added VRAM.") | |
| tr.add_argument("--moe_aux_coef", type=float, default=0.0, | |
| help="Weight for the MoE load-balance (Switch) aux loss. 0 disables (legacy). ~0.01 keeps both experts utilised under top-1 routing. Checkpoint-safe (router recomputed outside the checkpoint).") | |
| tr.add_argument("--moe_z_coef", type=float, default=0.0, | |
| help="Weight for the MoE router z-loss (router-logit magnitude regularizer). 0 disables. ~0.001 stabilizes routing.") | |
| tr.add_argument("--loss_spike_skip", type=float, default=0.0, | |
| help="Skip the optimizer step when the mean raw CE exceeds this multiple of its EMA (dblock path). 0 disables. ~3.0 drops pathological noisy-batch spikes.") | |
| tr.add_argument("--dblock", action="store_true", help="DiffusionBlocks block-wise denoising training (low VRAM).") | |
| tr.add_argument("--auto_dblock_search", action="store_true", help="Auto-search block configs") | |
| tr.add_argument("--dblock_blocks", type=int, default=4, help="Partition layers into this many DiffusionBlocks blocks.") | |
| tr.add_argument("--dblock_schedule", choices=["random", "roundrobin", "loss_balanced"], default="loss_balanced", | |
| help="How --dblock chooses the next layer block. loss_balanced focuses blocks whose EMA loss is highest after warmup.") | |
| tr.add_argument("--dblock_router", choices=["heuristic", "transformer"], default="heuristic", | |
| help="Optional learned sequence-Transformer scheduler for DBlock layer-band selection; coverage guards still enforce fairness.") | |
| tr.add_argument("--dblock_router_hidden", type=int, default=64, | |
| help="Hidden width for the context/history sequence-Transformer DBlock router.") | |
| tr.add_argument("--dblock_router_heads", type=int, default=4, | |
| help="Attention heads for the context/history sequence-Transformer DBlock router.") | |
| tr.add_argument("--dblock_router_layers", type=int, default=2, | |
| help="Transformer encoder layers for the context/history sequence-Transformer DBlock router.") | |
| tr.add_argument("--dblock_router_lr", type=float, default=0.002, | |
| help="Online learning rate for the context/history sequence-Transformer DBlock router.") | |
| tr.add_argument("--dblock_router_blend", type=float, default=0.35, | |
| help="Max blend of learned-router score into heuristic DBlock score after ramp-up.") | |
| tr.add_argument("--dblock_router_ramp_steps", type=int, default=256, | |
| help="DBlock steps over which the learned router ramps from 0 to --dblock_router_blend.") | |
| tr.add_argument("--dblock_warmup_steps", type=int, default=16, | |
| help="Initial DBlock steps spent covering every block before loss-balanced scheduling.") | |
| tr.add_argument("--dblock_explore", type=float, default=0.08, | |
| help="Exploration rate for loss-balanced DBlock scheduling.") | |
| tr.add_argument("--dblock_max_stale_steps", type=int, default=64, | |
| help="Force the stalest DBlock after this many unselected DBlock steps; 0 disables.") | |
| tr.add_argument("--dblock_max_count_skew", type=float, default=1.35, | |
| help="Force least-trained DBlock when max/min sampled block counts exceed this ratio; <=1 disables.") | |
| tr.add_argument("--dblock_stale_bonus", type=float, default=0.35, | |
| help="Loss-score bonus for stale DBlocks before the hard stale guard triggers.") | |
| tr.add_argument("--dblock_undertrain_bonus", type=float, default=0.25, | |
| help="Loss-score bonus for under-sampled DBlocks before the hard count-skew guard triggers.") | |
| tr.add_argument("--dblock_log_every", type=int, default=25, | |
| help="Print DBlock block/loss/VRAM diagnostics every N DBlock steps; 0 disables.") | |
| tr.add_argument("--dblock_sublayer_mode", choices=["off", "full", "attn_only", "ffn_only", "split_alt", "cycle"], default="off", | |
| help="Experimental dormant knob: train only transformer sublayers inside selected DiffusionBlocks. off/full keeps normal Block.forward; attn_only trains LN1+attention residual; ffn_only trains LN2+FFN/MoE residual; split_alt alternates attention/FFN by step; cycle rotates full/FFN/attention.") | |
| tr.add_argument("--dblock_checkpoint_stride", type=int, default=1, | |
| help="With --grad_checkpoint in --dblock mode, checkpoint one layer every N selected block layers; 1=all layers, 2=alternate, 0=off.") | |
| tr.add_argument("--dblock_checkpoint_skip_tail", type=int, default=0, | |
| help="Experimental DBlock speed knob: do not checkpoint this many final layers in the selected block, reducing backward recompute at higher VRAM cost.") | |
| tr.add_argument("--dblock_activation_offload", action="store_true", | |
| help="Experimental DBlock speed knob: for non-checkpointed block layers, offload saved backward tensors to CPU RAM instead of recomputing.") | |
| tr.add_argument("--dblock_activation_offload_min_mb", type=float, default=1.0, | |
| help="Minimum CUDA tensor size in MB to offload under --dblock_activation_offload.") | |
| tr.add_argument("--dblock_sigma_curriculum_steps", type=int, default=2000, | |
| help="Warm sigma ranges from easy to full span over this many DBlock steps; 0 disables.") | |
| tr.add_argument("--dblock_edm_wmax", type=float, default=5.0, | |
| help="Cap for EDM loss weighting in DBlock mode.") | |
| tr.add_argument("--dblock_ar_weight", type=float, default=1.0) | |
| tr.add_argument("--dblock_sat_weight", type=float, default=1.0) | |
| tr.add_argument("--dblock_nat_weight", type=float, default=1.0) | |
| tr.add_argument("--dblock_objective_mode", choices=["periodic", "stochastic"], default="periodic", | |
| help="DBlock objective scheduler. stochastic samples one objective per step to reduce redundant AR/SAT/NAT forwards.") | |
| tr.add_argument("--dblock_ar_prob", type=float, default=0.80, help="Stochastic DBlock probability for AR objective.") | |
| tr.add_argument("--dblock_sat_prob", type=float, default=0.10, help="Stochastic DBlock probability for SAT objective.") | |
| tr.add_argument("--dblock_nat_prob", type=float, default=0.10, help="Stochastic DBlock probability for NAT objective.") | |
| tr.add_argument("--dblock_ar_loss_tokens", type=int, default=0, | |
| help="If >0, uniformly sample this many AR target positions per DBlock step for stochastic token-level CE.") | |
| tr.add_argument("--dblock_sat_loss_tokens", type=int, default=0, | |
| help="If >0, uniformly sample this many SAT target positions per DBlock step.") | |
| tr.add_argument("--dblock_nat_loss_tokens", type=int, default=0, | |
| help="If >0, uniformly sample this many NAT target positions per DBlock step.") | |
| tr.add_argument("--reinit_nat", action="store_true", | |
| help="Reinitialize NAT head weights after load (use once when switching to mask-predict).") | |
| tr.add_argument("--seed_nat_from_ar", action="store_true", | |
| help="Seed the NAT head from the trained AR head ('father') after load instead of random init.") | |
| tr.add_argument("--freeze_core", action="store_true") | |
| tr.add_argument("--unfreeze_ln", action="store_true") | |
| tr.add_argument("--train_emb", action="store_true") | |
| tr.add_argument("--lr_core", type=float, default=LR_CORE) | |
| tr.add_argument("--lr_head", type=float, default=LR_HEAD) | |
| tr.add_argument("--chat", action="store_true") | |
| tr.add_argument("--chat_messages_key", default="messages") | |
| tr.add_argument("--dataset_field_text", default="text") | |
| tr.add_argument("--sft_add_generation_prompt", action="store_true") | |
| tr.add_argument("--auto_grow", action="store_true") | |
| tr.add_argument("--grow_plan", default="576,640,768,896,1024,1122") | |
| tr.add_argument("--grow_every_steps", type=int, default=50000) | |
| tr.add_argument("--after_sft_source", default="") | |
| tr.add_argument("--after_sft_steps", type=int, default=0) | |
| tr.add_argument("--after_sft_chat", action="store_true") | |
| tr.add_argument("--after_sft_chat_messages_key", default="messages") | |
| tr.add_argument("--after_sft_dataset_field_text", default="text") | |
| tr.add_argument("--after_sft_add_generation_prompt", type=bool, default=None) | |
| tr.add_argument("--after_sft_block", type=int, default=0) | |
| tr.add_argument("--after_sft_freeze_core", action="store_true") | |
| tr.add_argument("--after_sft_unfreeze_ln", action="store_true") | |
| tr.add_argument("--after_sft_train_emb", action="store_true") | |
| tr.add_argument("--after_sft_lr_core", type=float, default=0.0) | |
| tr.add_argument("--after_sft_lr_head", type=float, default=0.0) | |
| inf = sub.add_parser("infer") | |
| inf.add_argument("--mode", choices=["ar", "sat", "nat"], required=True) | |
| inf.add_argument("--sampler", choices=["ar", "euler"], default="ar", help="ar=KV decode; euler=DiffusionBlocks EDM Euler sampler.") | |
| inf.add_argument("--euler_steps", type=int, default=0, help="Euler ODE steps (0=2x dblock_blocks).") | |
| inf.add_argument("--euler_start_sigma", type=float, default=0.0, help="Euler start noise (0=sigma_max; lower=stronger context conditioning).") | |
| inf.add_argument("--dblock_blocks", type=int, default=4, help="Number of DiffusionBlocks for the Euler sampler.") | |
| inf.add_argument("--ckpt", required=True) | |
| inf.add_argument("--prompt", required=True) | |
| inf.add_argument("--max_new", type=int, default=120) | |
| inf.add_argument("--min_new", type=int, default=0, help="Minimum generated tokens before EOS can stop decoding. SAT enforces at least one block.") | |
| inf.add_argument("--temperature", type=float, default=None) | |
| inf.add_argument("--greedy", action="store_true") | |
| inf.add_argument("--top_k", type=int, default=None) | |
| inf.add_argument("--top_p", type=float, default=0.9) | |
| inf.add_argument("--min_p", type=float, default=0.0) | |
| inf.add_argument("--repetition_penalty", type=float, default=None) | |
| inf.add_argument("--presence_penalty", type=float, default=None) | |
| inf.add_argument("--frequency_penalty", type=float, default=None) | |
| inf.add_argument("--penalty_last_n", type=int, default=None) | |
| inf.add_argument("--var", action="store_true", default=None) | |
| inf.add_argument("--no-var", dest="var", action="store_false") | |
| inf.add_argument("--claude-friendly", action="store_true", help="Also print an artifact-free prompt/completion block for downstream JSON consumers") | |
| inf.add_argument("--plain-output", "--no-color", dest="plain_output", action="store_true", help="Use plain ASCII/no ANSI output for redirected inference logs") | |
| inf.add_argument("--attn_backend", choices=["manual", "sdpa", "sublinear"], default=DEFAULT_ATTN_BACKEND) | |
| inf.add_argument("--sublinear_window", type=int, default=DEFAULT_SUBLINEAR_WINDOW) | |
| inf.add_argument("--sublinear_stride", type=int, default=DEFAULT_SUBLINEAR_STRIDE) | |
| inf.add_argument("--sublinear_max_anchors", type=int, default=DEFAULT_SUBLINEAR_MAX_ANCHORS) | |
| inf.add_argument("--sublinear_chunk", type=int, default=DEFAULT_SUBLINEAR_CHUNK) | |
| inf.add_argument("--sublinear_sinks", type=int, default=DEFAULT_SUBLINEAR_SINKS) | |
| inf.add_argument("--sublinear_recent_anchors", type=int, default=DEFAULT_SUBLINEAR_RECENT_ANCHORS) | |
| inf.add_argument("--sublinear_pooled_landmarks", action=argparse.BooleanOptionalAction, | |
| default=DEFAULT_SUBLINEAR_POOLED_LANDMARKS) | |
| inf.add_argument("--no_structured_masks", action="store_true") | |
| inf.add_argument("--nat_expand", type=int, default=2) | |
| inf.add_argument("--nat_passes", type=int, default=1) | |
| inf.add_argument("--ignore_eos", action="store_true", | |
| help="Never stop on (or sample) EOS: suppress its logit and emit exactly max_new tokens. For base-model / SAT-head testing.") | |
| inf.add_argument("--infer_dtype", choices=["fp32", "fp16", "bf16"], default="fp32", | |
| help="Resident inference dtype. fp16/bf16 load on CPU, convert, then move the model to CUDA to avoid fp32 VRAM spikes.") | |
| inf.add_argument("--block_stream", action="store_true", | |
| help="VRAM-saving inference: keep heads/embeddings resident and page Encoder blocks through the compute device.") | |
| inf.add_argument("--block_stream_page_layers", type=int, default=1, | |
| help="Layers per resident page for --block_stream. 1=lowest VRAM; 0=use --dblock_blocks pages.") | |
| inf.add_argument("--block_stream_empty_cache", action=argparse.BooleanOptionalAction, default=True, | |
| help="Call torch.cuda.empty_cache() after each streamed page unload.") | |
| inf.add_argument("--block_stream_dtype", choices=["fp32", "fp16", "bf16"], default="fp32", | |
| help="Weight/activation dtype for --block_stream. fp16 halves CPU->GPU transfer bytes on CUDA-capable cards.") | |
| inf.add_argument("--block_stream_kv_cache", action=argparse.BooleanOptionalAction, default=True, | |
| help="Use KV cache for AR/SAT --block_stream decode instead of recomputing the full prefix each token.") | |
| inf.add_argument("--block_stream_kv_device", choices=["cuda", "cpu"], default="cuda", | |
| help="Where --block_stream keeps KV cache tensors. cuda is faster; cpu minimizes resident VRAM.") | |
| inf.add_argument("--block_stream_cache_pages", action=argparse.BooleanOptionalAction, default=None, | |
| help="Auto by default: keep streamed layer pages resident when VRAM allows. Use --no-block_stream_cache_pages for strict low-VRAM streaming.") | |
| inf.add_argument("--moe_expert_stream", action="store_true", | |
| help="With --block_stream, keep routed MoE experts on CPU and page only selected experts through the compute device.") | |
| inf.add_argument("--moe_expert_stream_empty_cache", action=argparse.BooleanOptionalAction, default=True, | |
| help="Call torch.cuda.empty_cache() after unloading each streamed MoE expert.") | |
| sup = sub.add_parser("supervise", help="Native AGILLM4.3 trainer supervisor") | |
| sup.add_argument("--save_dir", default="/workspace/agillm4_4090_ckpts") | |
| sup.add_argument("--side_dir", default="/workspace/agillm41_side_updates") | |
| sup.add_argument("--log", default="/workspace/agillm41_master_train.log") | |
| sup.add_argument("--pause_file", default="/tmp/agillm43_master_watchdog.pause") | |
| sup.add_argument("--sleep_sec", type=int, default=15) | |
| sup.add_argument("--dedupe", action=argparse.BooleanOptionalAction, default=True) | |
| sup.add_argument("--once", action="store_true") | |
| sup.add_argument("--profile", choices=AGILLM43_PROFILE_CHOICES, default="normal", | |
| help="Training launch profile: normal, ar_repair, full_ar_repair, sat_repair, or sat_probe.") | |
| hp = sub.add_parser("hotpatch", help="Flush checkpoint and restart under native AGILLM4.3 supervisor") | |
| hp.add_argument("--save_dir", default="/workspace/agillm4_4090_ckpts") | |
| hp.add_argument("--side_dir", default="/workspace/agillm41_side_updates") | |
| hp.add_argument("--log", default="/workspace/agillm41_master_train.log") | |
| hp.add_argument("--pause_file", default="/tmp/agillm43_master_watchdog.pause") | |
| hp.add_argument("--wait_flush_sec", type=int, default=900) | |
| hp.add_argument("--wait_start_sec", type=int, default=300) | |
| hp.add_argument("--sleep_sec", type=int, default=15) | |
| hp.add_argument("--profile", choices=AGILLM43_PROFILE_CHOICES, default="normal", | |
| help="Training launch profile used by the restarted supervisor.") | |
| hp.add_argument("--force", action="store_true") | |
| hp.add_argument("--tmux", action=argparse.BooleanOptionalAction, default=True) | |
| hp.add_argument("--tmux_session", default="master_wd") | |
| hp.add_argument("--kill_tmux", action=argparse.BooleanOptionalAction, default=True) | |
| hp.add_argument("--nohup_log", default="/workspace/agillm41_native_supervisor.nohup") | |
| st = sub.add_parser("status", help="Read-only training status") | |
| st.add_argument("--json", dest="json_output", action="store_true") | |
| st.add_argument("--log", type=str, default=str(STATUS_DEFAULT_LOG)) | |
| st.add_argument("--save_dir", type=str, default=str(STATUS_DEFAULT_SAVE_DIR)) | |
| args = ap.parse_args() | |
| if args.cmd == "train": train(args) | |
| elif args.cmd == "infer": infer(args) | |
| elif args.cmd == "supervise": raise SystemExit(supervise_agillm43(args)) | |
| elif args.cmd == "hotpatch": raise SystemExit(hotpatch_agillm43(args)) | |
| elif args.cmd == "status": raise SystemExit(_emit_status(Path(args.log), Path(args.save_dir), args.json_output)) | |
| else: raise SystemExit(f"unknown command: {args.cmd}") | |
| if __name__ == "__main__": | |
| main() | |
| # ===== END nB300_agillm4.py ===== | |
Xet Storage Details
- Size:
- 338 kB
- Xet hash:
- f17ba06f6f77ad9fc6976f282532775887076ecc8e44c109c7ca76733966d2f9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.