#!/usr/bin/env python3
# Source: OpenTransformer/agillm42-checkpoints, code/agillm43_diffusionblocks_independent.py (27.9 kB)
"""
AGILLM-4.3 DiffusionBlocks - INDEPENDENT (one-block-resident) training.
Implements the memory / parallelism property from the DiffusionBlocks paper
(arXiv:2506.14202), "Comparison with activation checkpointing":
  * Activation checkpointing reduces ONLY activation memory.
  * DiffusionBlocks reduces ALL memory components (params + grads + optimizer
    + activations) by a factor of ~B, because each of the B blocks is trained
    INDEPENDENTLY - only ONE block's params/grads/optimizer/activations are
    resident at a time.
  * Each block is embarrassingly parallel: B invocations on B machines, zero
    communication. They are then composed into ONE unified inference model.
Why this is valid for AGILLM-4.3: in agillm41._dblock_step every block reads the
EDM-noised embedding DIRECTLY (h = ci*zt, zt = emb + sigma*noise) and runs only
its OWN contiguous layer group - blocks are PARALLEL sigma-band denoisers that
share emb / ln / head, NOT a sequential stack. So block_b's forward
(emb -> layers[b] -> ln) is mathematically identical whether layers[b] live in a
full L-layer Encoder or in a standalone Encoder with L/B layers. That identity is
exactly what makes independent training + compose() exact.
This module REUSES the live trainer's real classes/objective (agillm41.Encoder,
ARHead, _block_sigmas, _edm_pre, _edm_w, fused_ce, _run_block). It does NOT import
or modify the training loop, and never needs to run on the training GPU
(CPU-only is fine; the math is identical).
Shared-stem policy: emb + final ln (+ AR head if untied) are a small SHARED stem.
Independent block training FREEZES the stem (loaded from a snapshot), so each
block-trainer only holds grads/optimizer for its L/B layers, and compose() is
EXACT (every block saw the identical frozen stem).
CLI:
    python agillm43_diffusionblocks_independent.py selftest
    python agillm43_diffusionblocks_independent.py mem-report [--d --layers --heads --rank --vocab] [--ckpt full.pt]
    python agillm43_diffusionblocks_independent.py make-stem --d.. --layers.. --B.. --out stem.pt
    python agillm43_diffusionblocks_independent.py stem-from-ckpt --ckpt full.pt --out stem.pt
    python agillm43_diffusionblocks_independent.py train-block (--init-ckpt full.pt | --stem stem.pt) --B 4 --block 0 --out b0.pt [--steps..]
    python agillm43_diffusionblocks_independent.py compose --blocks b0.pt b1.pt b2.pt b3.pt --out full.pt
    python agillm43_diffusionblocks_independent.py compose-into-ckpt --base-ckpt full.pt --blocks b0.pt b1.pt b2.pt b3.pt --out dbi_full.pt
"""
import os, sys, math, json, argparse, time
import torch
import torch.nn as nn
import agillm41 as A
import nB300_agillm4 as M
# ---- faithful reuse of the live trainer's building blocks ----
# Module-level aliases: every helper in this file goes through these names, so
# the objective is whatever agillm41 (imported as A) defines.
Encoder = A.Encoder
ARHead = A.ARHead
block_sigmas = A._block_sigmas  # used as block_sigmas(B); entries [b] and [b + 1] bound block b's sigma band
edm_pre = A._edm_pre  # used as edm_pre(sig) -> (cs, co, ci)
edm_w = A._edm_w  # used as edm_w(sig, wmax)
fused_ce = A.fused_ce
run_block = A._run_block
# The live AGILLM-4.3 (~1.22B): d_model=1280, 28 layers, 20 heads, low-rank 160.
# mem_report() falls back to this cfg when none is supplied.
LIVE_CFG = {"d": 1280, "layers": 28, "heads": 20, "rank": 160}
# ----------------------------- helpers --------------------------------------
def even_split(L, B):
    """Return the number of layers per block when L layers split evenly into B blocks.

    Args:
        L: total layer count (anything accepted by int()).
        B: number of blocks (anything accepted by int()); must be positive.

    Raises:
        ValueError: if B is not a positive integer, or L is not a multiple of B.
    """
    n_layers, n_blocks = int(L), int(B)
    # Guard first: B == 0 would otherwise surface as a ZeroDivisionError, and a
    # negative B would silently yield a negative layers-per-block.
    if n_blocks <= 0:
        raise ValueError(f"even-split demo: B={B} must be a positive integer.")
    if n_layers % n_blocks != 0:
        raise ValueError(
            f"even-split demo: L={L} not divisible by B={B}. The live "
            f"_dblock_block_layers handles a remainder; that is out of scope here.")
    return n_layers // n_blocks
def sub_cfg(cfg, lb):
    """Return a copy of `cfg` whose "layers" entry is int(lb); `cfg` itself is not modified."""
    block_cfg = dict(cfg)
    block_cfg.update(layers=int(lb))
    return block_cfg
def build_block_model(cfg, B, device="cpu", tie_weights=True, attn_backend="sdpa"):
    """Build one DiffusionBlocks block as a standalone model: emb + L/B layers + ln + AR head.

    Returns:
        (core, ar, lb): the Encoder holding cfg["layers"] / B layers, its ARHead,
        and the layers-per-block count. When `tie_weights` is true the head is
        constructed with the encoder's embedding weight.

    Raises:
        ValueError: (from even_split) if cfg["layers"] is not divisible by B.
    """
    layers_per_block = even_split(int(cfg["layers"]), B)
    block_cfg = sub_cfg(cfg, layers_per_block)
    core = Encoder(block_cfg, tie_weights=tie_weights, attn_backend=attn_backend)
    core = core.to(device)
    shared_weight = None
    if tie_weights:
        shared_weight = core.emb.weight
    head = ARHead(int(cfg["d"]), tie_weights=tie_weights, embedding_weight=shared_weight)
    head = head.to(device)
    return core, head, layers_per_block
def stem_param_list(core, ar, tie_weights):
    """List the shared-stem parameters: core.emb and core.ln, plus the AR head when untied."""
    stem_modules = [core.emb, core.ln]
    if not tie_weights:
        stem_modules.append(ar)
    return [param for module in stem_modules for param in module.parameters()]
def set_freeze(params, frozen=True):
    """Set requires_grad on every item of `params`: False when `frozen`, True otherwise."""
    wants_grad = not frozen
    for param in params:
        param.requires_grad = wants_grad
def unique_trainable(*modules):
    """Collect parameters with requires_grad=True from `modules`, de-duplicated by identity.

    A parameter reachable from several modules (e.g. a tied weight) appears once,
    at the position where it was first encountered.
    """
    first_seen = {}
    for module in modules:
        for param in module.parameters():
            if not param.requires_grad:
                continue
            # Keyed by id(): equality on tensors is elementwise, identity is what matters.
            first_seen.setdefault(id(param), param)
    return list(first_seen.values())
def _causal(T, device):
    """Fetch M.causal_mask(T, structured=False); move it to `device` only if it is a tensor."""
    mask = M.causal_mask(T, structured=False)
    return mask.to(device) if torch.is_tensor(mask) else mask
def _denoise_hidden(emb_mod, ln_mod, layers, ids, sig, noise, attn_args=None):
    """Deterministic EDM-preconditioned forward through `layers`.

    The original author's note says this mirrors agillm41._dblock_step.

    Args:
        emb_mod: embedding module, called as emb_mod(ids).
        ln_mod: final norm module applied to the combined output.
        layers: iterable of Block modules, run in order via run_block.
        ids: token ids; ids.size(1) is used as the sequence length.
        sig: per-sample sigma, indexed as sig[:, None, None] (so 1-D).
        noise: caller-supplied noise, so the forward is reproducible.
        attn_args: accepted for interface compatibility; not used here.

    Returns:
        ln_mod(cs * zt + co * h), where zt = emb + sigma * noise and h is the
        output of `layers` applied to ci * zt.
    """
    c_skip, c_out, c_in = edm_pre(sig)
    mask = _causal(ids.size(1), ids.device)
    noised = emb_mod(ids) + sig[:, None, None] * noise
    hidden = c_in * noised
    for layer in layers:
        hidden = run_block(layer, hidden, mask, False, None, "off")
    return ln_mod(c_skip * noised + c_out * hidden)
def dblock_ar_loss(core, ar, ids, lo, hi, wmax=5.0):
    """AR branch of agillm41._dblock_step, restricted to this block's sigma band.

    Args:
        core: encoder exposing .emb, .ln and .blocks.
        ar: AR head; only ar.proj.weight is read.
        ids: token ids; ids.size(0) sigmas are drawn (one per row).
        lo, hi: positive sigma band limits; sigma is sampled log-uniformly between them.
        wmax: forwarded to edm_w.

    Returns:
        (loss, raw): `loss` is the EDM-weighted cross-entropy reduced to a scalar
        tensor suitable for .backward(); `raw` is the unweighted CE as a float.
    """
    device = ids.device
    u = torch.rand(ids.size(0), device=device)
    log_lo, log_hi = math.log(lo), math.log(hi)
    sig = torch.exp(log_lo + u * (log_hi - log_lo))  # log-uniform in band
    w = edm_w(sig, wmax)
    # Only the output shape is needed here; skip autograd bookkeeping for this probe
    # (the real embedding forward happens inside _denoise_hidden).
    with torch.no_grad():
        emb_shape = core.emb(ids).shape
    noise = torch.randn(emb_shape, device=device)
    Dn = _denoise_hidden(core.emb, core.ln, core.blocks, ids, sig, noise)
    raw = fused_ce(Dn[:, :-1], ar.proj.weight, ids[:, 1:])
    # NOTE(review): edm_w's return shape is not visible from this file. `sig` is
    # per-sample, so `w * raw` may be shape (batch,); backward() requires a scalar.
    # .mean() is a no-op for a scalar weight and a batch average otherwise.
    loss = (w * raw).mean()
    return loss, float(raw.detach())
def opt_state_bytes(opt):
    """Sum the byte size of every tensor held in `opt.state`; non-tensor entries are ignored."""
    return sum(
        value.numel() * value.element_size()
        for per_param in opt.state.values()
        for value in per_param.values()
        if torch.is_tensor(value)
    )
def trainable_bytes(params):
    """Return the total byte size of the entries of `params` that have requires_grad set."""
    total = 0
    for param in params:
        if param.requires_grad:
            total += param.numel() * param.element_size()
    return total
# ----------------------------- stem I/O -------------------------------------
def make_stem(cfg, B, out_path, device="cpu", tie_weights=True, attn_backend="sdpa", seed=0):
    """Create a freshly initialised shared stem (emb + ln + AR head state) and optionally save it.

    Seeds the global torch RNG with `seed` before building the model. The payload
    is written with torch.save only when `out_path` is truthy, and is always
    returned. In real use the stem would instead be snapshotted from an existing
    checkpoint.
    """
    torch.manual_seed(seed)
    core, ar, _ = build_block_model(cfg, B, device, tie_weights, attn_backend)
    payload = dict(
        kind="diffusionblock_stem",
        cfg=dict(cfg),
        tie_weights=bool(tie_weights),
        emb=core.emb.state_dict(),
        ln=core.ln.state_dict(),
        ar=ar.state_dict(),
    )
    if out_path:
        torch.save(payload, out_path)
    return payload
def load_stem_into(core, ar, stem, tie_weights):
    """Load a stem payload into core.emb and core.ln, and into `ar` when weights are untied.

    `stem` is either an already-loaded payload or a string path, which is read
    with torch.load onto the CPU.
    """
    payload = torch.load(stem, map_location="cpu") if isinstance(stem, str) else stem
    core.emb.load_state_dict(payload["emb"])
    core.ln.load_state_dict(payload["ln"])
    if tie_weights:
        # The head is not loaded separately when weights are tied.
        return
    ar.load_state_dict(payload["ar"])
def _load_ckpt(path, device="cpu"):
    """Load a checkpoint with torch.load, mapping tensors onto `device`."""
    # NOTE(security): weights_only=False runs the full pickle machinery, which can
    # execute arbitrary code. Only use this on checkpoints from a trusted source.
    load_kwargs = {"map_location": device, "weights_only": False}
    return torch.load(path, **load_kwargs)
def _cfg_from_ckpt(ck):
    """Extract the model cfg from a checkpoint dict, returned as a fresh dict.

    Lookup order: ck["cfg"], then ck["seed_meta"]["v4_preset"], then
    ck["seed_meta"]["v3_preset"] (the first truthy preset wins).

    Raises:
        KeyError: if none of those sources yields a cfg.
    """
    not_found = "checkpoint has neither cfg nor seed_meta preset"
    if "cfg" in ck:
        return dict(ck["cfg"])
    if "seed_meta" not in ck:
        raise KeyError(not_found)
    seed_meta = ck["seed_meta"]
    for preset_name in ("v4_preset", "v3_preset"):
        preset = seed_meta.get(preset_name)
        if preset:
            return dict(preset)
    raise KeyError(not_found)
def _set_vocab_from_core_state(core_state):
    """Sync the global A.VOCAB to the row count of core_state["emb.weight"], if present.

    Returns the (possibly unchanged) A.VOCAB as an int. Note this mutates
    module-level state on agillm41.
    """
    try:
        emb_weight = core_state["emb.weight"]
    except KeyError:
        # No embedding in this state dict: leave the current vocab untouched.
        return int(A.VOCAB)
    A.VOCAB = int(emb_weight.shape[0])
    return int(A.VOCAB)
def _stem_payload_from_ckpt(ck):
    """Build a stem payload (emb + ln + AR head, all on CPU) from a full checkpoint dict.

    Reads ck["core"] for "emb.weight", "ln.weight" and "ln.bias", and takes the AR
    head from ck["ar"] when present. Also syncs A.VOCAB to the checkpoint's
    embedding and records source step / token count / tokenizer id as provenance.
    """
    core_state = ck["core"]
    _set_vocab_from_core_state(core_state)

    def _on_cpu(tensor):
        return tensor.detach().cpu()

    payload = {"kind": "diffusionblock_stem"}
    payload["cfg"] = _cfg_from_ckpt(ck)
    payload["tie_weights"] = bool(ck.get("tie_weights", False))
    payload["emb"] = {"weight": _on_cpu(core_state["emb.weight"])}
    payload["ln"] = {
        "weight": _on_cpu(core_state["ln.weight"]),
        "bias": _on_cpu(core_state["ln.bias"]),
    }
    payload["ar"] = {name: _on_cpu(tensor) for name, tensor in ck.get("ar", {}).items()}
    # `or 0` covers a key that is present but None.
    payload["source_step"] = int(ck.get("step", 0) or 0)
    payload["source_seen_tok"] = int(ck.get("seen_tok", 0) or 0)
    payload["tokenizer_id"] = ck.get("tokenizer_id")
    return payload
def save_stem_from_ckpt(ckpt_path, out_path):
    """Load the checkpoint at `ckpt_path` on CPU, save its stem to `out_path`, and return the stem."""
    checkpoint = _load_ckpt(ckpt_path, device="cpu")
    stem = _stem_payload_from_ckpt(checkpoint)
    torch.save(stem, out_path)
    return stem
def _block_layer_indices(total_layers, B, block_index):
    """Return the global layer indices owned by `block_index` under an even split into B blocks."""
    per_block = even_split(int(total_layers), int(B))
    first = int(block_index) * per_block
    return [first + offset for offset in range(per_block)]
def load_ckpt_block_into(core, ar, ck, cfg, B, block_index, tie_weights):
    """Initialise a standalone block model from a full checkpoint.

    Loads the shared stem (emb, ln, and the AR head when untied) and then the
    slice of ck["core"] belonging to `block_index`, renaming
    "blocks.<global>.*" keys to "<local>.*" so they fit core.blocks.

    Raises:
        RuntimeError: if the renamed slice does not match core.blocks exactly.
    """
    core_state = ck["core"]
    _set_vocab_from_core_state(core_state)
    core.emb.load_state_dict({"weight": core_state["emb.weight"]})
    core.ln.load_state_dict({name: core_state[f"ln.{name}"] for name in ("weight", "bias")})
    if not tie_weights:
        ar.load_state_dict(ck.get("ar", {}))

    global_indices = _block_layer_indices(int(cfg["layers"]), B, block_index)
    local_state = {}
    for local_i, global_i in enumerate(global_indices):
        src_prefix = f"blocks.{global_i}."
        local_state.update(
            (f"{local_i}." + key[len(src_prefix):], value)
            for key, value in core_state.items()
            if key.startswith(src_prefix)
        )

    # strict=False so both key lists can be reported together in one error.
    missing, unexpected = core.blocks.load_state_dict(local_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(f"block slice load mismatch: missing={missing[:8]} unexpected={unexpected[:8]}")
def compose_into_checkpoint(base_ckpt, block_ckpts, out_path, device="cpu"):
    """Write independently trained blocks back into a full AGILLM checkpoint.

    The base checkpoint supplies every field that is not overwritten. Its stem
    (emb / ln, and the AR head when untied and present) is replaced by block 0's,
    and each block's layers are copied to their global "blocks.<i>.*" keys.

    Args:
        base_ckpt: path of the full checkpoint to use as the template.
        block_ckpts: block checkpoints (paths or already-loaded payloads), one
            per block index 0..B-1, in any order.
        out_path: where the composed checkpoint is saved.
        device: map_location used when loading.

    Returns:
        The metadata dict stored under ck["diffusionblocks_independent"].

    Raises:
        ValueError: if no blocks are given, B differs between blocks, the block
            indices are not exactly 0..B-1, or a block was trained with a
            different layers-per-block than the base checkpoint implies.
    """
    block_ckpts = list(block_ckpts)
    if not block_ckpts:
        raise ValueError("compose-into-ckpt needs at least one block checkpoint")

    ck = _load_ckpt(base_ckpt, device=device)
    cfg = _cfg_from_ckpt(ck)
    core_state = ck["core"]
    payloads = [_load_ckpt(p, device=device) if isinstance(p, str) else p for p in block_ckpts]
    metas = [d["meta"] for d in payloads]

    B = int(metas[0]["B"])
    lb = even_split(int(cfg["layers"]), B)
    for m in metas:
        if int(m["B"]) != B:
            raise ValueError("compose-into-ckpt: mismatched B across block checkpoints")
        # A block trained against a different depth would only overwrite part of
        # its slot (or carry extra layers that are dropped) without any error.
        if int(m.get("layers_per_block", lb)) != lb:
            raise ValueError(
                f"compose-into-ckpt: block {m.get('block_index')} has "
                f"layers_per_block={m.get('layers_per_block')} but the base checkpoint implies {lb}")
    idxs = sorted(int(m["block_index"]) for m in metas)
    if idxs != list(range(B)):
        raise ValueError(f"compose-into-ckpt needs exactly blocks 0..{B-1}, got {idxs}")

    by_bi = {int(m["block_index"]): d for m, d in zip(metas, payloads)}

    # Shared stem: taken from block 0.
    b0 = by_bi[0]
    core_state["emb.weight"] = b0["emb"]["weight"].detach().cpu()
    core_state["ln.weight"] = b0["ln"]["weight"].detach().cpu()
    core_state["ln.bias"] = b0["ln"]["bias"].detach().cpu()
    if not bool(ck.get("tie_weights", False)) and "ar" in ck:
        ck["ar"] = {k: v.detach().cpu() for k, v in b0.get("ar", {}).items()}

    # Rename each block's local "<j>.*" keys to global "blocks.<bi*lb + j>.*".
    for bi in range(B):
        blk_sd = by_bi[bi]["blocks"]
        for local_i in range(lb):
            src_prefix = f"{local_i}."
            dst_prefix = f"blocks.{bi * lb + local_i}."
            for key, value in blk_sd.items():
                if key.startswith(src_prefix):
                    core_state[dst_prefix + key[len(src_prefix):]] = value.detach().cpu()

    ck["core"] = core_state
    ck["diffusionblocks_independent"] = {
        "kind": "diffusionblock_composed_checkpoint",
        "B": B,
        "layers_per_block": lb,
        "source_base_ckpt": str(base_ckpt),
        "block_count": len(block_ckpts),
        "composed_at_unix": time.time(),
    }
    torch.save(ck, out_path)
    return ck["diffusionblocks_independent"]
# ----------------------------- train one block ------------------------------
def random_batch(batch, seqlen, vocab, device):
    """Draw a (batch, seqlen) tensor of random token ids in [0, vocab) on `device`."""
    shape = (int(batch), int(seqlen))
    return torch.randint(low=0, high=int(vocab), size=shape, device=device)
def train_block(cfg, B, block_index, *, steps=60, batch=4, seqlen=32, lr=3e-4,
                tie_weights=True, device="cpu", attn_backend="sdpa",
                stem=None, init_ckpt=None, freeze_stem=True, out_path=None, log_every=0,
                data_fn=None, seed=0):
    """Train exactly ONE block independently.

    Embarrassingly parallel: run B of these (one per block_index) on B machines,
    no communication.

    Args:
        cfg: full-model cfg; cfg["layers"] must be divisible by B.
        B: number of blocks.
        block_index: which block to train, in [0, B).
        steps, batch, seqlen, lr: training-loop settings (batch / seqlen only
            apply to the default random data).
        tie_weights, device, attn_backend: forwarded to build_block_model.
        stem: stem payload or path; used only when `init_ckpt` is None.
        init_ckpt: path of a full checkpoint to initialise stem + block slice
            from; takes precedence over `stem`.
        freeze_stem: freeze emb / ln (and the AR head when untied).
        out_path: if truthy, save the trained block checkpoint there.
        log_every: print every this many steps (0 disables logging).
        data_fn: zero-argument callable returning a batch of token ids;
            defaults to uniform random ids over A.VOCAB.
        seed: combined with block_index to seed the global torch RNG.

    Returns:
        (info, (core, ar)): a summary dict and the trained modules.

    Raises:
        ValueError: if block_index is outside [0, B).
    """
    # A negative index would silently wrap around in the sigma table and train
    # the wrong band, so reject out-of-range values up front.
    if not 0 <= int(block_index) < int(B):
        raise ValueError(f"block_index={block_index} out of range for B={B} (expected 0..{int(B) - 1})")
    steps = int(steps)

    torch.manual_seed(1000 + seed + block_index)
    init_payload = _load_ckpt(init_ckpt, device="cpu") if init_ckpt is not None else None
    if init_payload is not None:
        # Must happen before the model is built so A.VOCAB matches the checkpoint.
        _set_vocab_from_core_state(init_payload["core"])
    sig = block_sigmas(B)
    lo, hi = sorted([sig[block_index], sig[block_index + 1]])
    core, ar, lb = build_block_model(cfg, B, device, tie_weights, attn_backend)
    if init_payload is not None:
        load_ckpt_block_into(core, ar, init_payload, cfg, B, block_index, tie_weights)
        # The weights have been loaded into the model; drop the full checkpoint
        # so only this block (plus stem) stays resident during training.
        init_payload = None
    elif stem is not None:
        load_stem_into(core, ar, stem, tie_weights)
    if freeze_stem:
        set_freeze(stem_param_list(core, ar, tie_weights), True)
    core.train()
    train_params = unique_trainable(core, ar)
    opt = torch.optim.AdamW(train_params, lr=lr)
    if data_fn is None:
        data_fn = lambda: random_batch(batch, seqlen, A.VOCAB, device)

    losses = []
    for step in range(steps):
        ids = data_fn()
        opt.zero_grad(set_to_none=True)
        loss, raw = dblock_ar_loss(core, ar, ids, lo, hi)
        loss.backward()
        nn.utils.clip_grad_norm_(train_params, 1.0)
        opt.step()
        losses.append(raw)
        if log_every and (step % log_every == 0 or step == steps - 1):
            print(f"[block {block_index}/{B}] step {step:4d} sigma[{lo:.3f},{hi:.3f}] ar_ce={raw:.4f}", flush=True)

    info = {
        "block_index": int(block_index), "B": int(B), "layers_per_block": int(lb),
        "sigma_lo": float(lo), "sigma_hi": float(hi),
        "loss_first": float(losses[0]) if losses else None,
        "loss_last": float(losses[-1]) if losses else None,
        "trainable_param_bytes": int(trainable_bytes(train_params)),
        "optimizer_state_bytes": int(opt_state_bytes(opt)),
        "n_trainable_params": int(sum(p.numel() for p in train_params)),
    }
    if out_path:
        # The stem is saved alongside the block so compose() can rebuild a full model.
        torch.save({
            "kind": "diffusionblock_independent",
            "meta": {"block_index": int(block_index), "B": int(B),
                     "layers_per_block": int(lb), "base_cfg": dict(cfg),
                     "tie_weights": bool(tie_weights), "attn_backend": attn_backend,
                     "sigma_lo": float(lo), "sigma_hi": float(hi)},
            "blocks": core.blocks.state_dict(),
            "emb": core.emb.state_dict(), "ln": core.ln.state_dict(),
            "ar": ar.state_dict(),
        }, out_path)
        info["out_path"] = out_path
    return info, (core, ar)
# ----------------------------- compose --------------------------------------
def _stem_mismatches(reference, candidate, parts):
    """Return "part.key" names whose tensors in `candidate` differ from `reference`."""
    differing = []
    for part in parts:
        cand_sd = candidate.get(part, {})
        for key, ref in reference.get(part, {}).items():
            if not torch.is_tensor(ref):
                continue
            cand = cand_sd.get(key)
            same = (
                torch.is_tensor(cand)
                and ref.shape == cand.shape
                and ref.dtype == cand.dtype
                and torch.equal(ref.detach().cpu(), cand.detach().cpu())
            )
            if not same:
                differing.append(f"{part}.{key}")
    return differing


def compose(block_ckpts, out_path=None, device="cpu"):
    """Reassemble B independently-trained blocks into ONE unified L-layer Encoder.

    Args:
        block_ckpts: block checkpoints (paths or loaded payloads) as written by
            train_block, one per block index 0..B-1, in any order.
        out_path: if truthy, save the composed model there.
        device: map_location for loading and device of the composed model.

    Returns:
        (full, ar, cfg, B, lb): composed Encoder, AR head, base cfg, block count
        and layers per block.

    Raises:
        ValueError: if no blocks are given, the indices are not exactly 0..B-1,
            or B / base_cfg / tie_weights differ between blocks.

    Warns:
        RuntimeWarning: if a block's saved stem differs from block 0's. Block 0's
            stem is still used, but the composed model is then no longer exact
            for that block (e.g. it was trained with the stem unfrozen).
    """
    import warnings

    block_ckpts = list(block_ckpts)
    if not block_ckpts:
        raise ValueError("compose needs at least one block checkpoint")
    payloads = [torch.load(p, map_location=device) if isinstance(p, str) else p for p in block_ckpts]
    metas = [d["meta"] for d in payloads]
    B = metas[0]["B"]
    cfg = dict(metas[0]["base_cfg"])
    tie = metas[0]["tie_weights"]
    ab = metas[0].get("attn_backend", "sdpa")
    lb = even_split(int(cfg["layers"]), B)

    idxs = sorted(m["block_index"] for m in metas)
    if idxs != list(range(B)):
        raise ValueError(f"compose needs exactly blocks 0..{B-1}, got {idxs}")
    for m in metas:
        if m["B"] != B or dict(m["base_cfg"]) != cfg:
            raise ValueError("compose: mismatched B / base_cfg across block checkpoints")
        if bool(m["tie_weights"]) != bool(tie):
            raise ValueError("compose: mismatched tie_weights across block checkpoints")

    full = Encoder(cfg, tie_weights=tie, attn_backend=ab).to(device)
    ew = full.emb.weight if tie else None
    ar = ARHead(int(cfg["d"]), tie_weights=tie, embedding_weight=ew).to(device)
    by_bi = {m["block_index"]: d for m, d in zip(metas, payloads)}

    # Shared stem: expected to be identical across blocks (all trained frozen to
    # the same stem) -> take block 0, but verify rather than assume.
    b0 = by_bi[0]
    stem_parts = ("emb", "ln") if tie else ("emb", "ln", "ar")
    for bi in range(1, B):
        differing = _stem_mismatches(b0, by_bi[bi], stem_parts)
        if differing:
            warnings.warn(
                f"compose: block {bi} stem differs from block 0 in {differing[:8]}; "
                f"using block 0's stem, so the composed model is not exact for block {bi}",
                RuntimeWarning, stacklevel=2)
    full.emb.load_state_dict(b0["emb"])
    full.ln.load_state_dict(b0["ln"])
    if not tie:
        ar.load_state_dict(b0["ar"])

    # Slot each block's L/B layers into the full stack.
    for bi in range(B):
        blk_sd = by_bi[bi]["blocks"]
        for li in range(lb):
            pref = f"{li}."
            sub = {k[len(pref):]: v for k, v in blk_sd.items() if k.startswith(pref)}
            full.blocks[bi * lb + li].load_state_dict(sub)

    if out_path:
        torch.save({"kind": "diffusionblock_composed", "cfg": cfg, "tie_weights": tie,
                    "attn_backend": ab, "B": B, "core": full.state_dict(),
                    "ar": ar.state_dict()}, out_path)
    return full, ar, cfg, B, lb
# ----------------------------- memory report --------------------------------
def _measure_per_layer_params(cfg):
    """Count parameters of one layer and of the final norm using a throwaway 1-layer Encoder.

    A.VOCAB is temporarily set to 8 so the probe's embedding stays tiny; the
    previous value is always restored.

    Returns:
        (p_layer, p_ln): parameter counts of enc.blocks (one layer) and enc.ln.
    """
    def _numel(module):
        return sum(p.numel() for p in module.parameters())

    original_vocab = A.VOCAB
    A.VOCAB = 8
    try:
        probe = Encoder(sub_cfg(cfg, 1), tie_weights=False, attn_backend="sdpa")
        layer_count = _numel(probe.blocks)
        ln_count = _numel(probe.ln)
        del probe
    finally:
        A.VOCAB = original_vocab
    return int(layer_count), int(ln_count)
def mem_report(cfg=None, vocab=None, Bs=(1, 2, 4, 7, 14, 28), bytes_per=4):
    """Print an analytic memory table for one-block-resident training.

    Args:
        cfg: model cfg; defaults to LIVE_CFG.
        vocab: vocabulary size; defaults to A.VOCAB.
        Bs: candidate block counts; those that do not divide the layer count are skipped.
        bytes_per: bytes per element used to convert parameter counts to GB.

    The model assumes a tied head (stem = emb + ln), 4x the trainable parameters
    for param + grad + two Adam moments, and a frozen stem that costs parameters only.
    """
    cfg = dict(cfg or LIVE_CFG)
    vocab = int(vocab or A.VOCAB)
    d = int(cfg["d"])
    L = int(cfg["layers"])
    p_layer, p_ln = _measure_per_layer_params(cfg)
    p_emb = vocab * d  # tied: AR head shares this
    p_stem = p_emb + p_ln  # shared, frozen during independent training
    p_full = p_stem + L * p_layer

    def GB(n):
        return n * bytes_per / 1e9

    print(f"# DiffusionBlocks memory model cfg={cfg} vocab={vocab} (fp32, {bytes_per}B/elem)")
    print(f"# per-layer params = {p_layer:,} shared stem (emb+ln, tied head) = {p_stem:,}")
    print(f"# FULL model params = {p_full:,} ({GB(p_full):.2f} GB params)")
    print(f"# Monolith TRAIN mem ~ 4*P_full (param+grad+2*Adam) = {GB(4*p_full):.2f} GB (+ activations A*L)")
    print()

    header_fmt = "{:>3} {:>4} {:>14} {:>26} {:>13} {:>20} {:>8}"
    row_fmt = "{:>3} {:>4} {:>14,} {:>23.2f} GB {:>10.2f} GB {:>16.2f} GB {:>7.2f}x"
    columns = ("B", "L/B", "trainable P", "4*train (param+grad+adam)", "frozen stem", "resident vs full 4P", "speedup")
    print(header_fmt.format(*columns))

    monolith = 4 * p_full
    for n_blocks in Bs:
        if L % n_blocks != 0:
            continue
        per_block = L // n_blocks
        trainable = per_block * p_layer
        trainable_cost = 4 * trainable  # param + grad + 2 Adam moments (trainable only)
        stem_cost = p_stem  # frozen: params only, no grad/opt
        resident = trainable_cost + stem_cost  # vs monolith 4*p_full
        reduction = monolith / resident
        print(row_fmt.format(
            n_blocks, per_block, trainable,
            GB(trainable_cost), GB(stem_cost), GB(resident), reduction))

    print()
    print("# 'speedup' = monolith 4*P_full / one-block-resident (4*trainable + frozen stem).")
    print("# The 4P (param+grad+Adam) term scales ~1/B; the frozen shared stem is the floor")
    print("# (itself offloadable/shardable). Activations scale 1/B too and combine with")
    print("# --grad_checkpoint (already in agillm41) for the paper's 4/3-time, 1/B-memory point.")
# ----------------------------- self test ------------------------------------
def selftest():
    """End-to-end check on a tiny model: train each block, compare memory, compose, round-trip.

    Steps: (1) make a shared stem, (2) train each of the B blocks independently
    with the stem frozen and require its loss to drop, (3) compare one block's
    optimizer state against a monolithic model's, (4) compose the blocks, and
    (5) verify that each block's forward is unchanged inside the composed model.

    Checkpoints are written to a temporary directory that is removed afterwards.
    A.VOCAB is set to 256 for the duration and then restored.

    Returns:
        0 if every check passed, 1 otherwise (suitable for sys.exit).
    """
    import tempfile

    torch.manual_seed(0)
    dev = "cpu"
    cfg = {"d": 64, "layers": 8, "heads": 4, "rank": 16}
    B = 4
    saved_vocab = A.VOCAB
    ok = True
    try:
        A.VOCAB = 256  # tiny vocab -> fast/small selftest
        print(f"[selftest] cfg={cfg} B={B} vocab={A.VOCAB}")
        # A private temp dir avoids hard-coded /tmp paths: portable, no collisions
        # between concurrent runs, and nothing is left behind.
        with tempfile.TemporaryDirectory(prefix="dbi_selftest_") as tmpdir:
            # 1) shared stem
            stem = make_stem(cfg, B, None, device=dev, tie_weights=True, seed=7)

            # 2) train each block INDEPENDENTLY (frozen stem)
            blocks, infos = [], []
            for bi in range(B):
                block_path = os.path.join(tmpdir, f"_dbi_b{bi}.pt")
                info, _ = train_block(cfg, B, bi, steps=40, batch=4, seqlen=24, lr=5e-3,
                                      tie_weights=True, device=dev, stem=stem,
                                      freeze_stem=True, out_path=block_path, seed=3)
                infos.append(info)
                blocks.append(block_path)
                dec = info["loss_first"] - info["loss_last"]
                print(f" block {bi}: ar_ce {info['loss_first']:.4f} -> {info['loss_last']:.4f} "
                      f"(drop {dec:+.4f}), trainable={info['n_trainable_params']:,} "
                      f"opt_state={info['optimizer_state_bytes']/1e6:.2f}MB")
                # independent training reduces its band loss
                ok = ok and (info["loss_last"] < info["loss_first"])

            # 3) memory: one block's trainable+opt vs a MONOLITHIC full model
            full_core = Encoder(cfg, tie_weights=True, attn_backend="sdpa").to(dev)
            full_ar = ARHead(cfg["d"], tie_weights=True, embedding_weight=full_core.emb.weight).to(dev)
            full_params = unique_trainable(full_core, full_ar)
            full_opt = torch.optim.AdamW(full_params, lr=1e-3)
            # one adam step so optimizer state exists
            ids = random_batch(4, 24, A.VOCAB, dev)
            sigs = block_sigmas(B)
            full_loss, _ = dblock_ar_loss(full_core, full_ar, ids, sigs[0], sigs[-1])
            full_loss.backward()
            full_opt.step()
            full_opt_mb = opt_state_bytes(full_opt) / 1e6
            one_opt_mb = infos[0]["optimizer_state_bytes"] / 1e6
            ratio = full_opt_mb / max(one_opt_mb, 1e-9)
            print(f"[selftest] optimizer-state full={full_opt_mb:.2f}MB one-block={one_opt_mb:.2f}MB ratio={ratio:.2f}x (target ~{B}x on layer params)")
            # full opt state includes the (large, tied) emb head too; on the LAYER portion the ratio ~ B.
            ok = ok and (one_opt_mb < full_opt_mb)

            # 4) compose -> unified model
            full, ar, ccfg, cB, lb = compose(blocks, out_path=os.path.join(tmpdir, "_dbi_full.pt"), device=dev)
            full.eval()
            print(f"[selftest] composed full Encoder: layers={ccfg['layers']} from B={cB} x lb={lb}")

            # 5) EXACT round-trip: composed full's block-slice forward == standalone block forward
            max_diff = 0.0
            for bi in range(B):
                torch.manual_seed(500 + bi)
                ids = random_batch(2, 20, A.VOCAB, dev)
                sgl = block_sigmas(B)
                lo, hi = sgl[bi], sgl[bi + 1]
                # Probe at the geometric midpoint of the block's sigma band.
                sig = torch.full((ids.size(0),), float((lo * hi) ** 0.5), device=dev)
                noise = torch.randn(full.emb(ids).shape, device=dev)
                # standalone block (reload its ckpt)
                d = torch.load(blocks[bi], map_location=dev)
                sc = Encoder(sub_cfg(cfg, lb), tie_weights=True, attn_backend="sdpa").to(dev)
                sc.emb.load_state_dict(d["emb"])
                sc.ln.load_state_dict(d["ln"])
                sc.blocks.load_state_dict(d["blocks"])
                sc.eval()
                Dn_standalone = _denoise_hidden(sc.emb, sc.ln, sc.blocks, ids, sig, noise)
                # composed full, using only this block's layer slice
                sl = [full.blocks[bi * lb + j] for j in range(lb)]
                Dn_composed = _denoise_hidden(full.emb, full.ln, sl, ids, sig, noise)
                diff = float((Dn_standalone - Dn_composed).abs().max())
                max_diff = max(max_diff, diff)
            print(f"[selftest] compose round-trip max|Δhidden| = {max_diff:.2e} (want < 1e-4)")
            ok = ok and (max_diff < 1e-4)

        print()
        print("=" * 70)
        print(f"[selftest] {'PASS' if ok else 'FAIL'}")
        print("=" * 70)
    finally:
        A.VOCAB = saved_vocab
    return 0 if ok else 1
# ----------------------------- CLI ------------------------------------------
def _cfg_from_args(a):
    """Build a model cfg dict from the d / layers / heads / rank attributes of parsed CLI args."""
    return {field: getattr(a, field) for field in ("d", "layers", "heads", "rank")}
def main():
    """Command-line entry point.

    Subcommands: selftest, mem-report, stem-from-ckpt, make-stem, train-block,
    compose, compose-into-ckpt. `selftest` exits the process with the self-test's
    return code; every other subcommand prints a short summary and returns.
    """
    def _add_cfg_args(parser):
        # Same four architecture flags (and defaults) on every subcommand that builds a model.
        for name, dv in (("d", 1280), ("layers", 28), ("heads", 20), ("rank", 160)):
            parser.add_argument(f"--{name}", type=int, default=dv)

    ap = argparse.ArgumentParser(description="AGILLM-4.3 DiffusionBlocks independent training")
    sub = ap.add_subparsers(dest="cmd", required=True)

    sub.add_parser("selftest")

    m = sub.add_parser("mem-report")
    _add_cfg_args(m)
    m.add_argument("--vocab", type=int, default=0)
    m.add_argument("--ckpt", default=None, help="Read cfg/vocab from an AGILLM checkpoint")

    sc = sub.add_parser("stem-from-ckpt")
    sc.add_argument("--ckpt", required=True)
    sc.add_argument("--out", required=True)

    mk = sub.add_parser("make-stem")
    _add_cfg_args(mk)
    mk.add_argument("--B", type=int, required=True)
    mk.add_argument("--out", required=True)
    mk.add_argument("--no-tie", action="store_true")

    tb = sub.add_parser("train-block")
    _add_cfg_args(tb)
    tb.add_argument("--B", type=int, required=True)
    tb.add_argument("--block", type=int, required=True)
    tb.add_argument("--stem", default=None)
    tb.add_argument("--init-ckpt", default=None, help="Initialize stem and this block slice from a full AGILLM checkpoint")
    tb.add_argument("--out", required=True)
    tb.add_argument("--steps", type=int, default=200)
    tb.add_argument("--batch", type=int, default=4)
    tb.add_argument("--seqlen", type=int, default=64)
    tb.add_argument("--lr", type=float, default=3e-4)
    tb.add_argument("--no-tie", action="store_true")
    tb.add_argument("--no-freeze-stem", action="store_true")
    tb.add_argument("--device", default="cpu")
    tb.add_argument("--log-every", type=int, default=20)

    cp = sub.add_parser("compose")
    cp.add_argument("--blocks", nargs="+", required=True)
    cp.add_argument("--out", required=True)
    cp.add_argument("--device", default="cpu")

    cic = sub.add_parser("compose-into-ckpt")
    cic.add_argument("--base-ckpt", required=True)
    cic.add_argument("--blocks", nargs="+", required=True)
    cic.add_argument("--out", required=True)
    cic.add_argument("--device", default="cpu")

    a = ap.parse_args()

    if a.cmd == "selftest":
        sys.exit(selftest())

    if a.cmd == "mem-report":
        if a.ckpt:
            ck = _load_ckpt(a.ckpt, device="cpu")
            mem_report(_cfg_from_ckpt(ck), vocab=(a.vocab or _set_vocab_from_core_state(ck["core"])))
        else:
            mem_report(_cfg_from_args(a), vocab=(a.vocab or None))
        return

    if a.cmd == "stem-from-ckpt":
        payload = save_stem_from_ckpt(a.ckpt, a.out)
        print(f"[stem-from-ckpt] saved stem -> {a.out} step={payload.get('source_step')}")
        return

    if a.cmd == "make-stem":
        make_stem(_cfg_from_args(a), a.B, a.out, tie_weights=not a.no_tie)
        print(f"[make-stem] saved stem -> {a.out}")
        return

    if a.cmd == "train-block":
        cfg = _cfg_from_args(a)
        tie_weights = not a.no_tie
        if a.init_ckpt:
            # With --init-ckpt, cfg / tie_weights / vocab come from the checkpoint
            # and override the CLI flags.
            ck = _load_ckpt(a.init_ckpt, device="cpu")
            cfg = _cfg_from_ckpt(ck)
            tie_weights = bool(ck.get("tie_weights", tie_weights))
            _set_vocab_from_core_state(ck["core"])
            # train_block loads the checkpoint itself; release this copy so the
            # full model is not held twice while one block trains.
            del ck
        info, _ = train_block(cfg, a.B, a.block, steps=a.steps, batch=a.batch,
                              seqlen=a.seqlen, lr=a.lr, tie_weights=tie_weights,
                              device=a.device, stem=a.stem, init_ckpt=a.init_ckpt,
                              freeze_stem=not a.no_freeze_stem,
                              out_path=a.out, log_every=a.log_every)
        print(json.dumps(info, indent=2))
        return

    if a.cmd == "compose":
        full, ar, cfg, B, lb = compose(a.blocks, out_path=a.out, device=a.device)
        print(f"[compose] {B} blocks x {lb} layers -> full {cfg['layers']}-layer model saved {a.out}")
        return

    if a.cmd == "compose-into-ckpt":
        meta = compose_into_checkpoint(a.base_ckpt, a.blocks, a.out, device=a.device)
        print(f"[compose-into-ckpt] B={meta['B']} x {meta['layers_per_block']} layers -> {a.out}")
        return
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

# Xet Storage Details
# Size: 27.9 kB
# Xet hash: 1739855aaa3708e2800c535f0fc9d5082cdafafc61bf106f627168158299b193
# Xet efficiently stores files, intelligently splitting them into unique chunks
# and accelerating uploads and downloads.