Buckets:
| #!/usr/bin/env python3 | |
| """ | |
| AGILLM-4.3 DiffusionBlocks - INDEPENDENT (one-block-resident) training. | |
| Implements the memory / parallelism property from the DiffusionBlocks paper | |
| (arXiv:2506.14202), "Comparison with activation checkpointing": | |
| * Activation checkpointing reduces ONLY activation memory. | |
| * DiffusionBlocks reduces ALL memory components (params + grads + optimizer | |
| + activations) by ~B, because each of the B blocks is trained INDEPENDENTLY | |
| - only ONE block's params/grads/optimizer/activations are resident at a time. | |
| * Each block is embarrassingly parallel: B invocations on B machines, zero | |
| communication. They are then composed into ONE unified inference model. | |
| Why this is valid for AGILLM-4.3: in agillm41._dblock_step every block reads the | |
| EDM-noised embedding DIRECTLY (h = ci*zt, zt = emb + sigma*noise) and runs only | |
| its OWN contiguous layer group - blocks are PARALLEL sigma-band denoisers that | |
| share emb / ln / head, NOT a sequential stack. So block_b's forward | |
| (emb -> layers[b] -> ln) is mathematically identical whether layers[b] live in a | |
| full L-layer Encoder or in a standalone Encoder with L/B layers. That identity is | |
| exactly what makes independent training + compose() exact. | |
| This module REUSES the live trainer's real classes/objective (agillm41.Encoder, | |
| ARHead, _block_sigmas, _edm_pre, _edm_w, fused_ce, _run_block). It does NOT import | |
| or modify the training loop, and never needs to run on the training GPU | |
| (CPU-only is fine; the math is identical). | |
| Shared-stem policy: emb + final ln (+ AR head if untied) are a small SHARED stem. | |
| Independent block training FREEZES the stem (loaded from a snapshot), so each | |
| block-trainer only holds grads/optimizer for its L/B layers, and compose() is | |
| EXACT (every block saw the identical frozen stem). | |
| CLI: | |
| python agillm43_diffusionblocks_independent.py selftest | |
| python agillm43_diffusionblocks_independent.py mem-report [--d --layers --heads --rank --vocab] | |
| python agillm43_diffusionblocks_independent.py make-stem --d.. --layers.. --B.. --out stem.pt | |
| python agillm43_diffusionblocks_independent.py stem-from-ckpt --ckpt full.pt --out stem.pt | |
| python agillm43_diffusionblocks_independent.py train-block --init-ckpt full.pt --B 4 --block 0 --out b0.pt [--steps..] | |
| python agillm43_diffusionblocks_independent.py compose --blocks b0.pt b1.pt b2.pt b3.pt --out full.pt | |
| python agillm43_diffusionblocks_independent.py compose-into-ckpt --base-ckpt full.pt --blocks b0.pt b1.pt b2.pt b3.pt --out dbi_full.pt | |
| """ | |
| import os, sys, math, json, argparse, time | |
| import torch | |
| import torch.nn as nn | |
| import agillm41 as A | |
| import nB300_agillm4 as M | |
# ---- faithful reuse of the live trainer's building blocks ----
# Module-level aliases for names defined in agillm41 (imported as A). The
# private helpers are re-exported here without their leading underscore so the
# rest of this file can call them directly.
Encoder = A.Encoder
ARHead = A.ARHead
block_sigmas = A._block_sigmas
edm_pre = A._edm_pre
edm_w = A._edm_w
fused_ce = A.fused_ce
run_block = A._run_block
# The live AGILLM-4.3 (~1.22B): d_model=1280, 28 layers, 20 heads, low-rank 160.
# mem_report() falls back to this config when none is supplied.
LIVE_CFG = {"d": 1280, "layers": 28, "heads": 20, "rank": 160}
| # ----------------------------- helpers -------------------------------------- | |
def even_split(L, B):
    """Return the number of layers per block when L layers split evenly into B blocks.

    Args:
        L: total layer count (coerced with int()).
        B: number of blocks (coerced with int()); must be >= 1.

    Returns:
        int: L // B.

    Raises:
        ValueError: if B < 1, or if L is not divisible by B.
    """
    n_layers, n_blocks = int(L), int(B)
    # Guard first: B == 0 would otherwise surface as ZeroDivisionError, and a
    # negative B that happens to divide L would yield a negative layer count.
    if n_blocks < 1:
        raise ValueError(f"even-split demo: B={B} must be a positive integer.")
    if n_layers % n_blocks != 0:
        raise ValueError(
            f"even-split demo: L={L} not divisible by B={B}. The live "
            f"_dblock_block_layers handles a remainder; that is out of scope here.")
    return n_layers // n_blocks
def sub_cfg(cfg, lb):
    """Return a shallow copy of `cfg` whose "layers" entry is int(lb).

    The input mapping is not modified.
    """
    block_cfg = dict(cfg)
    block_cfg.update(layers=int(lb))
    return block_cfg
def build_block_model(cfg, B, device="cpu", tie_weights=True, attn_backend="sdpa"):
    """Build one DiffusionBlocks block as a standalone model.

    The model is an Encoder with cfg["layers"] / B layers plus an ARHead.
    When `tie_weights` is true the head is handed the encoder's embedding
    weight; otherwise it is built without a shared weight.

    Returns:
        (core, ar, lb): the Encoder, the ARHead, and the layers-per-block count.

    Raises:
        ValueError: propagated from even_split when the layers do not divide by B.
    """
    layers_per_block = even_split(int(cfg["layers"]), B)
    block_cfg = sub_cfg(cfg, layers_per_block)
    core = Encoder(block_cfg, tie_weights=tie_weights, attn_backend=attn_backend)
    core = core.to(device)
    shared_weight = None
    if tie_weights:
        shared_weight = core.emb.weight
    head = ARHead(int(cfg["d"]), tie_weights=tie_weights, embedding_weight=shared_weight)
    head = head.to(device)
    return core, head, layers_per_block
def stem_param_list(core, ar, tie_weights):
    """List the shared-stem parameters: embedding, final norm, and (if untied) the AR head.

    Order is emb parameters, then ln parameters, then ar parameters.
    """
    stem_modules = [core.emb, core.ln]
    if not tie_weights:
        # An untied head owns its own weights, so it belongs to the stem too.
        stem_modules.append(ar)
    return [param for module in stem_modules for param in module.parameters()]
def set_freeze(params, frozen=True):
    """Set requires_grad on every parameter in `params` to `not frozen`."""
    trainable = not frozen
    for param in params:
        param.requires_grad = trainable
def unique_trainable(*modules):
    """Collect the trainable parameters of `modules`, de-duplicated by identity.

    A parameter shared between modules (e.g. a tied weight) appears once, at
    the position where it was first encountered.
    """
    first_seen = {}
    for module in modules:
        for param in module.parameters():
            if param.requires_grad:
                # dicts preserve insertion order, so first occurrence wins.
                first_seen.setdefault(id(param), param)
    return list(first_seen.values())
def _causal(T, device):
    """Fetch M.causal_mask(T, structured=False), moved to `device` when it is a tensor.

    A non-tensor result is returned unchanged.
    """
    mask = M.causal_mask(T, structured=False)
    return mask.to(device) if torch.is_tensor(mask) else mask
def _denoise_hidden(emb_mod, ln_mod, layers, ids, sig, noise, attn_args=None):
    """Deterministic EDM-preconditioned forward through `layers` (mirrors _dblock_step).

    Args:
        emb_mod: module mapping token ids to embeddings.
        ln_mod: module applied to the final preconditioned combination.
        layers: iterable of Block modules, applied in order.
        ids: token ids; dimension 1 is used as the sequence length.
        sig: per-sample noise levels, broadcast as sig[:, None, None].
        noise: noise tensor added to the embeddings, supplied by the caller
            for reproducibility.
        attn_args: accepted but not used by this function.

    Returns:
        ln_mod(cs * zt + co * h), with (cs, co, ci) taken from edm_pre(sig).
    """
    seq_len = ids.size(1)
    c_skip, c_out, c_in = edm_pre(sig)
    mask = _causal(seq_len, ids.device)
    noised = emb_mod(ids) + sig[:, None, None] * noise
    hidden = c_in * noised
    for layer in layers:
        # NOTE(review): the trailing positional arguments are forwarded to
        # agillm41._run_block verbatim; confirm their meaning against that signature.
        hidden = run_block(layer, hidden, mask, False, None, "off")
    return ln_mod(c_skip * noised + c_out * hidden)
def dblock_ar_loss(core, ar, ids, lo, hi, wmax=5.0):
    """AR branch of agillm41._dblock_step, restricted to this block's sigma band.

    Draws one sigma per sample log-uniformly in [lo, hi], denoises through
    core.blocks, and scores next-token prediction with fused_ce against
    ar.proj.weight.

    Returns:
        (weighted, raw): edm_w(sig, wmax) * ce, and the unweighted ce as a float.
    """
    dev = ids.device
    log_lo, log_hi = math.log(lo), math.log(hi)
    frac = torch.rand(ids.size(0), device=dev)
    sig = torch.exp(log_lo + frac * (log_hi - log_lo))  # log-uniform in band
    weight = edm_w(sig, wmax)
    # The embedding is evaluated here only to learn the noise shape;
    # _denoise_hidden embeds the ids again for the actual forward pass.
    noise_shape = core.emb(ids).shape
    noise = torch.randn(noise_shape, device=dev)
    denoised = _denoise_hidden(core.emb, core.ln, core.blocks, ids, sig, noise)
    # Position t predicts token t+1: drop the last hidden state and the first id.
    ce = fused_ce(denoised[:, :-1], ar.proj.weight, ids[:, 1:])
    return weight * ce, float(ce.detach())
def opt_state_bytes(opt):
    """Sum the byte size of every tensor held in `opt.state`.

    Non-tensor state entries (e.g. Python step counters) are ignored.
    """
    return sum(
        entry.numel() * entry.element_size()
        for per_param in opt.state.values()
        for entry in per_param.values()
        if torch.is_tensor(entry)
    )
def trainable_bytes(params):
    """Return the total byte size of the parameters in `params` that require grad."""
    total = 0
    for param in params:
        if param.requires_grad:
            total += param.numel() * param.element_size()
    return total
| # ----------------------------- stem I/O ------------------------------------- | |
def make_stem(cfg, B, out_path, device="cpu", tie_weights=True, attn_backend="sdpa", seed=0):
    """Create (and optionally save) a freshly initialised shared stem.

    The stem is the embedding and final norm of a block model built with
    build_block_model, plus the AR head's state dict. In real use you would
    instead snapshot the stem from an existing checkpoint.

    Args:
        out_path: where to torch.save the payload; a falsy value skips saving.
        seed: passed to torch.manual_seed before the model is built.

    Returns:
        The stem payload dict.
    """
    torch.manual_seed(seed)
    core, head, _ = build_block_model(cfg, B, device, tie_weights, attn_backend)
    stem = {"kind": "diffusionblock_stem"}
    stem["cfg"] = dict(cfg)
    stem["tie_weights"] = bool(tie_weights)
    stem["emb"] = core.emb.state_dict()
    stem["ln"] = core.ln.state_dict()
    stem["ar"] = head.state_dict()
    if out_path:
        torch.save(stem, out_path)
    return stem
def load_stem_into(core, ar, stem, tie_weights):
    """Load a shared stem into `core` (and into `ar` when the head is untied).

    Args:
        stem: a stem payload dict, or a path string that is loaded to CPU first.
        tie_weights: when true the AR head is left untouched.
    """
    payload = torch.load(stem, map_location="cpu") if isinstance(stem, str) else stem
    for module, key in ((core.emb, "emb"), (core.ln, "ln")):
        module.load_state_dict(payload[key])
    if tie_weights:
        return
    ar.load_state_dict(payload["ar"])
| def _load_ckpt(path, device="cpu"): | |
| return torch.load(path, map_location=device, weights_only=False) | |
| def _cfg_from_ckpt(ck): | |
| if "cfg" in ck: | |
| return dict(ck["cfg"]) | |
| if "seed_meta" in ck: | |
| cfg = ck["seed_meta"].get("v4_preset") or ck["seed_meta"].get("v3_preset") | |
| if cfg: | |
| return dict(cfg) | |
| raise KeyError("checkpoint has neither cfg nor seed_meta preset") | |
def _set_vocab_from_core_state(core_state):
    """Set the global A.VOCAB from the embedding row count, if present.

    When `core_state` has an "emb.weight" entry, A.VOCAB is overwritten with
    its first dimension; otherwise A.VOCAB is left as is.

    Returns:
        int: the resulting A.VOCAB.
    """
    try:
        emb_weight = core_state["emb.weight"]
    except KeyError:
        pass
    else:
        A.VOCAB = int(emb_weight.shape[0])
    return int(A.VOCAB)
def _stem_payload_from_ckpt(ck):
    """Build a stem payload (emb + ln + AR head) from a full checkpoint dict.

    Side effect: A.VOCAB is updated from the checkpoint's embedding via
    _set_vocab_from_core_state. Every tensor in the payload is detached and
    moved to CPU. A missing "tie_weights" entry is treated as False, and
    missing step / seen_tok counters as 0.
    """
    core_state = ck["core"]
    _set_vocab_from_core_state(core_state)

    def cpu_copy(tensor):
        return tensor.detach().cpu()

    return {
        "kind": "diffusionblock_stem",
        "cfg": _cfg_from_ckpt(ck),
        "tie_weights": bool(ck.get("tie_weights", False)),
        "emb": {"weight": cpu_copy(core_state["emb.weight"])},
        "ln": {
            "weight": cpu_copy(core_state["ln.weight"]),
            "bias": cpu_copy(core_state["ln.bias"]),
        },
        "ar": {name: cpu_copy(tensor) for name, tensor in ck.get("ar", {}).items()},
        # `or 0` also covers a stored None.
        "source_step": int(ck.get("step", 0) or 0),
        "source_seen_tok": int(ck.get("seen_tok", 0) or 0),
        "tokenizer_id": ck.get("tokenizer_id"),
    }
def save_stem_from_ckpt(ckpt_path, out_path):
    """Load a full checkpoint on CPU, extract its stem, save it to `out_path`.

    Returns:
        The saved stem payload dict.
    """
    checkpoint = _load_ckpt(ckpt_path, device="cpu")
    stem = _stem_payload_from_ckpt(checkpoint)
    torch.save(stem, out_path)
    return stem
def _block_layer_indices(total_layers, B, block_index):
    """Return the global layer indices owned by block `block_index`.

    Blocks are contiguous and equally sized: block i owns layers
    [i * L/B, (i + 1) * L/B).
    """
    per_block = even_split(int(total_layers), int(B))
    first = int(block_index) * per_block
    return [first + offset for offset in range(per_block)]
def load_ckpt_block_into(core, ar, ck, cfg, B, block_index, tie_weights):
    """Initialise a standalone block model from a full checkpoint.

    Loads the checkpoint's embedding and final norm into `core`, the AR head
    into `ar` when untied, and the layer slice belonging to `block_index` into
    core.blocks, renumbering "blocks.<global>." keys to "<local>.".

    Side effect: A.VOCAB is updated via _set_vocab_from_core_state.

    Raises:
        RuntimeError: if the remapped slice leaves missing or unexpected keys.
    """
    core_state = ck["core"]
    _set_vocab_from_core_state(core_state)
    core.emb.load_state_dict({"weight": core_state["emb.weight"]})
    ln_state = {"weight": core_state["ln.weight"], "bias": core_state["ln.bias"]}
    core.ln.load_state_dict(ln_state)
    if not tie_weights:
        ar.load_state_dict(ck.get("ar", {}))

    global_ids = _block_layer_indices(int(cfg["layers"]), B, block_index)
    source_prefixes = [f"blocks.{global_i}." for global_i in global_ids]
    remapped = {
        f"{local_i}." + key[len(prefix):]: tensor
        for local_i, prefix in enumerate(source_prefixes)
        for key, tensor in core_state.items()
        if key.startswith(prefix)
    }
    # strict=False so the mismatch can be reported with the message below.
    missing, unexpected = core.blocks.load_state_dict(remapped, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"block slice load mismatch: missing={missing[:8]} unexpected={unexpected[:8]}")
def compose_into_checkpoint(base_ckpt, block_ckpts, out_path, device="cpu"):
    """Write independently trained blocks back into a copy of a full checkpoint.

    The base checkpoint supplies every field that is not overwritten. The
    stem (emb, ln, and the AR head when the base is untied and has an "ar"
    entry) is taken from block 0; each block's layers are written to their
    global "blocks.<i>." slots.

    Args:
        base_ckpt: path of the full checkpoint to start from.
        block_ckpts: block checkpoints as paths or already-loaded payload dicts.
        out_path: where the composed checkpoint is saved.
        device: map_location used when loading.

    Returns:
        The metadata dict stored under ck["diffusionblocks_independent"].

    Raises:
        ValueError: if no blocks are given, the block set is not exactly
            0..B-1, B differs across blocks, or a block was trained with a
            different layers-per-block than the base checkpoint implies.
    """
    block_ckpts = list(block_ckpts)
    # Fail before loading the (potentially multi-GB) base checkpoint.
    if not block_ckpts:
        raise ValueError("compose-into-ckpt needs at least one block checkpoint")

    ck = _load_ckpt(base_ckpt, device=device)
    cfg = _cfg_from_ckpt(ck)
    core_state = ck["core"]
    payloads = [torch.load(p, map_location=device, weights_only=False) if isinstance(p, str) else p
                for p in block_ckpts]
    metas = [d["meta"] for d in payloads]
    B = int(metas[0]["B"])
    lb = even_split(int(cfg["layers"]), B)
    idxs = sorted(int(m["block_index"]) for m in metas)
    if idxs != list(range(B)):
        raise ValueError(f"compose-into-ckpt needs exactly blocks 0..{B-1}, got {idxs}")
    for m in metas:
        if int(m["B"]) != B:
            raise ValueError("compose-into-ckpt: mismatched B across block checkpoints")
        # A block trained against a different depth would otherwise be copied
        # into the wrong global layer slots without any error.
        if int(m.get("layers_per_block", lb)) != lb:
            raise ValueError(
                f"compose-into-ckpt: block {m['block_index']} has "
                f"layers_per_block={m['layers_per_block']} but the base checkpoint implies {lb}")

    by_bi = {int(m["block_index"]): d for m, d in zip(metas, payloads)}
    b0 = by_bi[0]
    core_state["emb.weight"] = b0["emb"]["weight"].detach().cpu()
    core_state["ln.weight"] = b0["ln"]["weight"].detach().cpu()
    core_state["ln.bias"] = b0["ln"]["bias"].detach().cpu()
    if not bool(ck.get("tie_weights", False)) and "ar" in ck:
        ck["ar"] = {k: v.detach().cpu() for k, v in b0.get("ar", {}).items()}

    for bi in range(B):
        blk_sd = by_bi[bi]["blocks"]
        for local_i in range(lb):
            src_prefix = f"{local_i}."
            dst_prefix = f"blocks.{bi * lb + local_i}."
            for key, value in blk_sd.items():
                if key.startswith(src_prefix):
                    core_state[dst_prefix + key[len(src_prefix):]] = value.detach().cpu()

    ck["core"] = core_state
    ck["diffusionblocks_independent"] = {
        "kind": "diffusionblock_composed_checkpoint",
        "B": B,
        "layers_per_block": lb,
        "source_base_ckpt": str(base_ckpt),
        "block_count": len(block_ckpts),
        "composed_at_unix": time.time(),
    }
    torch.save(ck, out_path)
    return ck["diffusionblocks_independent"]
| # ----------------------------- train one block ------------------------------ | |
def random_batch(batch, seqlen, vocab, device):
    """Return a (batch, seqlen) tensor of token ids drawn uniformly from [0, vocab)."""
    shape = (int(batch), int(seqlen))
    return torch.randint(0, int(vocab), shape, device=device)
def train_block(cfg, B, block_index, *, steps=60, batch=4, seqlen=32, lr=3e-4,
                tie_weights=True, device="cpu", attn_backend="sdpa",
                stem=None, init_ckpt=None, freeze_stem=True, out_path=None, log_every=0,
                data_fn=None, seed=0):
    """Train exactly ONE block independently.

    Embarrassingly parallel: run B of these (one per block_index) on B
    machines, no communication.

    Args:
        cfg: full-model config dict ("d", "layers", ...).
        B: number of blocks the full model is split into.
        block_index: which block to train; must satisfy 0 <= block_index < B.
        steps, batch, seqlen, lr: optimisation settings. `batch` and `seqlen`
            are only used by the default random data source.
        stem: stem payload or path, used only when `init_ckpt` is None.
        init_ckpt: full checkpoint path; when given, the stem and this block's
            layer slice are loaded from it and `stem` is ignored.
        freeze_stem: freeze emb / ln (and the AR head if untied) before
            collecting the trainable parameters.
        out_path: if truthy, the trained block is saved there.
        log_every: print the loss every this many steps (0 disables logging).
        data_fn: zero-argument callable returning a batch of token ids;
            defaults to uniformly random ids over A.VOCAB.
        seed: combined with block_index to seed torch.

    Returns:
        (info, (core, ar)): a summary dict and the trained modules.

    Raises:
        ValueError: if B < 1 or block_index is outside [0, B).
    """
    n_blocks, bi = int(B), int(block_index)
    if n_blocks < 1:
        raise ValueError(f"train_block: B={B} must be a positive integer")
    # A negative index would silently wrap: sig[-1], sig[0] spans the whole
    # sigma range instead of one band. Reject it up front.
    if not 0 <= bi < n_blocks:
        raise ValueError(
            f"train_block: block_index={block_index} out of range for B={B} "
            f"(need 0 <= block_index < B)")
    n_steps = int(steps)

    torch.manual_seed(1000 + seed + block_index)
    init_payload = _load_ckpt(init_ckpt, device="cpu") if init_ckpt is not None else None
    if init_payload is not None:
        # Must happen before the model is built so the embedding has the
        # checkpoint's vocab size.
        _set_vocab_from_core_state(init_payload["core"])
    sig = block_sigmas(B)
    lo, hi = sorted([sig[block_index], sig[block_index + 1]])
    core, ar, lb = build_block_model(cfg, B, device, tie_weights, attn_backend)
    if init_payload is not None:
        load_ckpt_block_into(core, ar, init_payload, cfg, B, block_index, tie_weights)
    elif stem is not None:
        load_stem_into(core, ar, stem, tie_weights)
    if freeze_stem:
        set_freeze(stem_param_list(core, ar, tie_weights), True)
    core.train()
    train_params = unique_trainable(core, ar)
    opt = torch.optim.AdamW(train_params, lr=lr)
    if data_fn is None:
        data_fn = lambda: random_batch(batch, seqlen, A.VOCAB, device)

    losses = []
    for step in range(n_steps):
        ids = data_fn()
        opt.zero_grad(set_to_none=True)
        loss, raw = dblock_ar_loss(core, ar, ids, lo, hi)
        loss.backward()
        nn.utils.clip_grad_norm_(train_params, 1.0)
        opt.step()
        losses.append(raw)
        if log_every and (step % log_every == 0 or step == n_steps - 1):
            print(f"[block {block_index}/{B}] step {step:4d} sigma[{lo:.3f},{hi:.3f}] ar_ce={raw:.4f}", flush=True)

    info = {
        "block_index": int(block_index), "B": int(B), "layers_per_block": int(lb),
        "sigma_lo": float(lo), "sigma_hi": float(hi),
        "loss_first": float(losses[0]) if losses else None,
        "loss_last": float(losses[-1]) if losses else None,
        "trainable_param_bytes": int(trainable_bytes(train_params)),
        "optimizer_state_bytes": int(opt_state_bytes(opt)),
        "n_trainable_params": int(sum(p.numel() for p in train_params)),
    }
    if out_path:
        torch.save({
            "kind": "diffusionblock_independent",
            "meta": {"block_index": int(block_index), "B": int(B),
                     "layers_per_block": int(lb), "base_cfg": dict(cfg),
                     "tie_weights": bool(tie_weights), "attn_backend": attn_backend,
                     "sigma_lo": float(lo), "sigma_hi": float(hi)},
            "blocks": core.blocks.state_dict(),
            # The stem is saved with every block so compose() can rebuild the
            # full model from block files alone.
            "emb": core.emb.state_dict(), "ln": core.ln.state_dict(),
            "ar": ar.state_dict(),
        }, out_path)
        info["out_path"] = out_path
    return info, (core, ar)
| # ----------------------------- compose -------------------------------------- | |
| def compose(block_ckpts, out_path=None, device="cpu"): | |
| """Reassemble B independently-trained blocks into ONE unified L-layer Encoder.""" | |
| payloads = [torch.load(p, map_location=device) if isinstance(p, str) else p for p in block_ckpts] | |
| metas = [d["meta"] for d in payloads] | |
| B = metas[0]["B"]; cfg = dict(metas[0]["base_cfg"]); tie = metas[0]["tie_weights"] | |
| ab = metas[0].get("attn_backend", "sdpa"); lb = even_split(int(cfg["layers"]), B) | |
| idxs = sorted(m["block_index"] for m in metas) | |
| if idxs != list(range(B)): | |
| raise ValueError(f"compose needs exactly blocks 0..{B-1}, got {idxs}") | |
| for m in metas: | |
| if m["B"] != B or dict(m["base_cfg"]) != cfg: | |
| raise ValueError("compose: mismatched B / base_cfg across block checkpoints") | |
| full = Encoder(cfg, tie_weights=tie, attn_backend=ab).to(device) | |
| ew = full.emb.weight if tie else None | |
| ar = ARHead(int(cfg["d"]), tie_weights=tie, embedding_weight=ew).to(device) | |
| by_bi = {m["block_index"]: d for m, d in zip(metas, payloads)} | |
| # shared stem: identical across blocks (all trained frozen to the same stem) -> take block 0 | |
| b0 = by_bi[0] | |
| full.emb.load_state_dict(b0["emb"]); full.ln.load_state_dict(b0["ln"]) | |
| if not tie: | |
| ar.load_state_dict(b0["ar"]) | |
| # slot each block's L/B layers into the full stack | |
| for bi in range(B): | |
| blk_sd = by_bi[bi]["blocks"] | |
| for li in range(lb): | |
| pref = f"{li}." | |
| sub = {k[len(pref):]: v for k, v in blk_sd.items() if k.startswith(pref)} | |
| full.blocks[bi * lb + li].load_state_dict(sub) | |
| if out_path: | |
| torch.save({"kind": "diffusionblock_composed", "cfg": cfg, "tie_weights": tie, | |
| "attn_backend": ab, "B": B, "core": full.state_dict(), | |
| "ar": ar.state_dict()}, out_path) | |
| return full, ar, cfg, B, lb | |
| # ----------------------------- memory report -------------------------------- | |
def _measure_per_layer_params(cfg):
    """Count the parameters of one encoder layer and of the final norm.

    Builds a throwaway 1-layer Encoder with A.VOCAB temporarily set to 8 so
    the embedding stays tiny; A.VOCAB is restored even if construction fails.

    Returns:
        (p_layer, p_ln): parameter counts of enc.blocks and enc.ln.
    """
    def count(module):
        return sum(p.numel() for p in module.parameters())

    original_vocab = A.VOCAB
    A.VOCAB = 8
    try:
        probe = Encoder(sub_cfg(cfg, 1), tie_weights=False, attn_backend="sdpa")
        layer_params, ln_params = count(probe.blocks), count(probe.ln)
        del probe
    finally:
        A.VOCAB = original_vocab
    return int(layer_params), int(ln_params)
def mem_report(cfg=None, vocab=None, Bs=(1, 2, 4, 7, 14, 28), bytes_per=4):
    """Print an analytic memory comparison: monolithic vs one-block-resident training.

    Args:
        cfg: model config; defaults to LIVE_CFG.
        vocab: vocabulary size; a falsy value falls back to A.VOCAB.
        Bs: block counts to tabulate; values that do not divide the layer
            count are skipped.
        bytes_per: bytes per element used to convert counts to GB.

    The model assumes a tied head (the stem is emb + ln) and 4 resident copies
    per trainable parameter (param + grad + 2 Adam moments).
    """
    cfg = dict(cfg or LIVE_CFG)
    vocab = int(vocab or A.VOCAB)
    d = int(cfg["d"])
    L = int(cfg["layers"])
    p_layer, p_ln = _measure_per_layer_params(cfg)
    p_stem = vocab * d + p_ln  # tied embedding/head + final norm; frozen during independent training
    p_full = p_stem + L * p_layer

    def GB(n):
        return n * bytes_per / 1e9

    print(f"# DiffusionBlocks memory model cfg={cfg} vocab={vocab} (fp32, {bytes_per}B/elem)")
    print(f"# per-layer params = {p_layer:,} shared stem (emb+ln, tied head) = {p_stem:,}")
    print(f"# FULL model params = {p_full:,} ({GB(p_full):.2f} GB params)")
    print(f"# Monolith TRAIN mem ~ 4*P_full (param+grad+2*Adam) = {GB(4*p_full):.2f} GB (+ activations A*L)")
    print()
    hdr = ("B", "L/B", "trainable P", "4*train (param+grad+adam)", "frozen stem", "resident vs full 4P", "speedup")
    print("{:>3} {:>4} {:>14} {:>26} {:>13} {:>20} {:>8}".format(*hdr))
    row_fmt = "{:>3} {:>4} {:>14,} {:>23.2f} GB {:>10.2f} GB {:>16.2f} GB {:>7.2f}x"
    monolith = 4 * p_full
    for B in Bs:
        lb, remainder = divmod(L, B)
        if remainder:
            continue
        p_train = lb * p_layer
        train_mem = 4 * p_train  # param + grad + 2 Adam moments (trainable only)
        resident = train_mem + p_stem  # frozen stem: params only, no grad/opt
        print(row_fmt.format(
            B, lb, p_train, GB(train_mem), GB(p_stem), GB(resident), monolith / resident))
    print()
    print("# 'speedup' = monolith 4*P_full / one-block-resident (4*trainable + frozen stem).")
    print("# The 4P (param+grad+Adam) term scales ~1/B; the frozen shared stem is the floor")
    print("# (itself offloadable/shardable). Activations scale 1/B too and combine with")
    print("# --grad_checkpoint (already in agillm41) for the paper's 4/3-time, 1/B-memory point.")
| # ----------------------------- self test ------------------------------------ | |
def selftest():
    """End-to-end smoke test on a tiny model.

    Steps: build a shared stem, train each block independently, compare
    optimizer-state size against a monolithic model, compose, and check that
    each block's forward is unchanged by composition.

    All checkpoint files live in a private temporary directory that is removed
    on exit, so concurrent runs cannot collide and nothing is left behind.
    A.VOCAB is temporarily set to 256 and restored afterwards.

    Returns:
        int: 0 if every check passed, 1 otherwise (suitable for sys.exit).
    """
    import tempfile

    torch.manual_seed(0)
    dev = "cpu"
    cfg = {"d": 64, "layers": 8, "heads": 4, "rank": 16}
    B = 4
    saved_vocab = A.VOCAB
    ok = True
    try:
        A.VOCAB = 256  # tiny vocab -> fast/small selftest
        print(f"[selftest] cfg={cfg} B={B} vocab={A.VOCAB}")
        with tempfile.TemporaryDirectory(prefix="dbi_selftest_") as tmpdir:
            # 1) shared stem
            stem = make_stem(cfg, B, None, device=dev, tie_weights=True, seed=7)
            # 2) train each block INDEPENDENTLY (frozen stem)
            blocks, infos = [], []
            for bi in range(B):
                block_path = os.path.join(tmpdir, f"_dbi_b{bi}.pt")
                info, _ = train_block(cfg, B, bi, steps=40, batch=4, seqlen=24, lr=5e-3,
                                      tie_weights=True, device=dev, stem=stem,
                                      freeze_stem=True, out_path=block_path, seed=3)
                infos.append(info)
                blocks.append(block_path)
                dec = info["loss_first"] - info["loss_last"]
                print(f" block {bi}: ar_ce {info['loss_first']:.4f} -> {info['loss_last']:.4f} "
                      f"(drop {dec:+.4f}), trainable={info['n_trainable_params']:,} "
                      f"opt_state={info['optimizer_state_bytes']/1e6:.2f}MB")
                ok &= (info["loss_last"] < info["loss_first"])  # independent training reduces its band loss
            # 3) memory: one block's trainable+opt vs a MONOLITHIC full model
            full_core = Encoder(cfg, tie_weights=True, attn_backend="sdpa").to(dev)
            full_ar = ARHead(cfg["d"], tie_weights=True, embedding_weight=full_core.emb.weight).to(dev)
            full_params = unique_trainable(full_core, full_ar)
            full_opt = torch.optim.AdamW(full_params, lr=1e-3)
            # one adam step so optimizer state exists
            ids = random_batch(4, 24, A.VOCAB, dev)
            sigs = block_sigmas(B)
            l, _ = dblock_ar_loss(full_core, full_ar, ids, sigs[0], sigs[-1])
            l.backward()
            full_opt.step()
            full_opt_mb = opt_state_bytes(full_opt) / 1e6
            one_opt_mb = infos[0]["optimizer_state_bytes"] / 1e6
            ratio = full_opt_mb / max(one_opt_mb, 1e-9)
            print(f"[selftest] optimizer-state full={full_opt_mb:.2f}MB one-block={one_opt_mb:.2f}MB ratio={ratio:.2f}x (target ~{B}x on layer params)")
            # full opt state includes the (large, tied) emb head too; on the LAYER portion the ratio ~ B.
            ok &= (one_opt_mb < full_opt_mb)
            # 4) compose -> unified model
            full, ar, ccfg, cB, lb = compose(blocks, out_path=os.path.join(tmpdir, "_dbi_full.pt"), device=dev)
            full.eval()
            print(f"[selftest] composed full Encoder: layers={ccfg['layers']} from B={cB} x lb={lb}")
            # 5) EXACT round-trip: composed full's block-slice forward == standalone block forward
            max_diff = 0.0
            for bi in range(B):
                torch.manual_seed(500 + bi)
                ids = random_batch(2, 20, A.VOCAB, dev)
                sgl = block_sigmas(B)
                lo, hi = sgl[bi], sgl[bi + 1]
                # evaluate at the geometric midpoint of the block's sigma band
                sig = torch.full((ids.size(0),), float((lo * hi) ** 0.5), device=dev)
                noise = torch.randn(full.emb(ids).shape, device=dev)
                # standalone block (reload its ckpt)
                d = torch.load(blocks[bi], map_location=dev)
                sc = Encoder(sub_cfg(cfg, lb), tie_weights=True, attn_backend="sdpa").to(dev)
                sc.emb.load_state_dict(d["emb"])
                sc.ln.load_state_dict(d["ln"])
                sc.blocks.load_state_dict(d["blocks"])
                sc.eval()
                Dn_standalone = _denoise_hidden(sc.emb, sc.ln, sc.blocks, ids, sig, noise)
                # composed full, using only this block's layer slice
                sl = [full.blocks[bi * lb + j] for j in range(lb)]
                Dn_composed = _denoise_hidden(full.emb, full.ln, sl, ids, sig, noise)
                diff = float((Dn_standalone - Dn_composed).abs().max())
                max_diff = max(max_diff, diff)
            print(f"[selftest] compose round-trip max|Δhidden| = {max_diff:.2e} (want < 1e-4)")
            ok &= (max_diff < 1e-4)
        print()
        print("=" * 70)
        print(f"[selftest] {'PASS' if ok else 'FAIL'}")
        print("=" * 70)
    finally:
        A.VOCAB = saved_vocab
    return 0 if ok else 1
| # ----------------------------- CLI ------------------------------------------ | |
| def _cfg_from_args(a): | |
| return {"d": a.d, "layers": a.layers, "heads": a.heads, "rank": a.rank} | |
def _add_model_args(parser):
    """Register the --d/--layers/--heads/--rank options (defaults: live AGILLM-4.3 config)."""
    for name, dv in (("d", 1280), ("layers", 28), ("heads", 20), ("rank", 160)):
        parser.add_argument(f"--{name}", type=int, default=dv)


def main():
    """Command-line entry point; dispatches to the subcommand named in argv.

    Subcommands: selftest, mem-report, stem-from-ckpt, make-stem, train-block,
    compose, compose-into-ckpt. `selftest` exits the process with the
    selftest's return code; the others print a short summary and return.
    """
    ap = argparse.ArgumentParser(description="AGILLM-4.3 DiffusionBlocks independent training")
    sub = ap.add_subparsers(dest="cmd", required=True)

    sub.add_parser("selftest")

    m = sub.add_parser("mem-report")
    _add_model_args(m)
    m.add_argument("--vocab", type=int, default=0)
    m.add_argument("--ckpt", default=None, help="Read cfg/vocab from an AGILLM checkpoint")

    sc = sub.add_parser("stem-from-ckpt")
    sc.add_argument("--ckpt", required=True)
    sc.add_argument("--out", required=True)

    mk = sub.add_parser("make-stem")
    _add_model_args(mk)
    mk.add_argument("--B", type=int, required=True)
    mk.add_argument("--out", required=True)
    mk.add_argument("--no-tie", action="store_true")

    tb = sub.add_parser("train-block")
    _add_model_args(tb)
    tb.add_argument("--B", type=int, required=True)
    tb.add_argument("--block", type=int, required=True)
    tb.add_argument("--stem", default=None)
    tb.add_argument("--init-ckpt", default=None, help="Initialize stem and this block slice from a full AGILLM checkpoint")
    tb.add_argument("--out", required=True)
    tb.add_argument("--steps", type=int, default=200)
    tb.add_argument("--batch", type=int, default=4)
    tb.add_argument("--seqlen", type=int, default=64)
    tb.add_argument("--lr", type=float, default=3e-4)
    tb.add_argument("--no-tie", action="store_true")
    tb.add_argument("--no-freeze-stem", action="store_true")
    tb.add_argument("--device", default="cpu")
    tb.add_argument("--log-every", type=int, default=20)

    cp = sub.add_parser("compose")
    cp.add_argument("--blocks", nargs="+", required=True)
    cp.add_argument("--out", required=True)
    cp.add_argument("--device", default="cpu")

    cic = sub.add_parser("compose-into-ckpt")
    cic.add_argument("--base-ckpt", required=True)
    cic.add_argument("--blocks", nargs="+", required=True)
    cic.add_argument("--out", required=True)
    cic.add_argument("--device", default="cpu")

    a = ap.parse_args()

    if a.cmd == "selftest":
        sys.exit(selftest())

    if a.cmd == "mem-report":
        if a.ckpt:
            ck = _load_ckpt(a.ckpt, device="cpu")
            # An explicit --vocab wins; otherwise use the checkpoint's embedding size.
            mem_report(_cfg_from_ckpt(ck), vocab=(a.vocab or _set_vocab_from_core_state(ck["core"])))
        else:
            mem_report(_cfg_from_args(a), vocab=(a.vocab or None))
        return

    if a.cmd == "stem-from-ckpt":
        payload = save_stem_from_ckpt(a.ckpt, a.out)
        print(f"[stem-from-ckpt] saved stem -> {a.out} step={payload.get('source_step')}")
        return

    if a.cmd == "make-stem":
        make_stem(_cfg_from_args(a), a.B, a.out, tie_weights=not a.no_tie)
        print(f"[make-stem] saved stem -> {a.out}")
        return

    if a.cmd == "train-block":
        # Out-of-range indices would otherwise fail deep inside training, or
        # (for negative values) wrap around to the wrong sigma band.
        if not 0 <= a.block < a.B:
            ap.error(f"--block must satisfy 0 <= block < B (got block={a.block}, B={a.B})")
        cfg = _cfg_from_args(a)
        tie_weights = not a.no_tie
        if a.init_ckpt:
            if a.stem:
                # train_block prefers init_ckpt over stem; say so instead of
                # silently dropping the argument.
                print("[train-block] warning: --stem is ignored because --init-ckpt was given",
                      file=sys.stderr)
            ck = _load_ckpt(a.init_ckpt, device="cpu")
            # The checkpoint's config and tying policy override the CLI flags.
            cfg = _cfg_from_ckpt(ck)
            tie_weights = bool(ck.get("tie_weights", tie_weights))
            _set_vocab_from_core_state(ck["core"])
        info, _ = train_block(cfg, a.B, a.block, steps=a.steps, batch=a.batch,
                              seqlen=a.seqlen, lr=a.lr, tie_weights=tie_weights,
                              device=a.device, stem=a.stem, init_ckpt=a.init_ckpt,
                              freeze_stem=not a.no_freeze_stem,
                              out_path=a.out, log_every=a.log_every)
        print(json.dumps(info, indent=2))
        return

    if a.cmd == "compose":
        full, ar, cfg, B, lb = compose(a.blocks, out_path=a.out, device=a.device)
        print(f"[compose] {B} blocks x {lb} layers -> full {cfg['layers']}-layer model saved {a.out}")
        return

    if a.cmd == "compose-into-ckpt":
        meta = compose_into_checkpoint(a.base_ckpt, a.blocks, a.out, device=a.device)
        print(f"[compose-into-ckpt] B={meta['B']} x {meta['layers_per_block']} layers -> {a.out}")
        return


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 27.9 kB
- Xet hash:
- 1739855aaa3708e2800c535f0fc9d5082cdafafc61bf106f627168158299b193
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.