Spaces:
Running on Zero
Running on Zero
"""
train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's
"Tell Math a secret" demo points at).
Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning
specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a
fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading
nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced
in training; decoded autoregressively at eval).
By default (--holdout 0) the bridge trains on EVERY (a, op, b) problem and the reported
"accuracy" is coverage of the whole table (memorize mode, the lookup-table demo). With
--holdout N (e.g. 8), roughly N% of problems are HELD OUT of training, so eval accuracy
on them measures generalization, not memorization. Ablating the latent removes the
question entirely -> accuracy collapses to the digit prior.
Saves links/qa__language__from__reasoning.safetensors in the same key style as the
key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load.
Run: python agents/modmind/train_qa_link.py [--steps 4000] [--holdout 8] [--device cuda]
"""
| from __future__ import annotations | |
| import argparse | |
| import hashlib | |
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| import torch | |
| import torch.nn.functional as F | |
| HERE = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.insert(0, HERE) | |
| from model import RecursiveLink, SpikeWhaleLM # noqa: E402 | |
| from specialist_presets import specialist_config # noqa: E402 | |
| from spike_tokenizer import SpikeTokenizer # noqa: E402 | |
# Domain names: the trainable asker that emits answers, and the frozen consultant
# that is the only model allowed to see the question.
ASKER, CONSULTANT = "language", "reasoning"
D_LATENT = 256  # width of the consultant latent fed into the RecursiveLink
PROMPT = "ANS> "  # the asker's entire text input before the answer digits
ANS_LEN = 3  # answers zero-padded to 3 digits ("077"); every answer is 0..198
# Suggested --holdout value for a generalization run. NOTE: nothing in this script
# reads this constant -- the CLI default is --holdout 0 (train on ALL problems).
HOLDOUT_PCT = 8
OUT = os.path.join(HERE, "links", f"qa__{ASKER}__from__{CONSULTANT}.safetensors")
| # ---- the problem space -------------------------------------------------------- | |
def all_problems():
    """Return every (a, op, b) the bridge is trained/evaluated on, in a fixed order.

    Two-digit operands 10..99 contribute "+" for every pair and "-" only when
    a >= b (so no answer is negative); operands 2..12 contribute the "*" table.
    Answers are 0..198.
    """
    two_digit = range(10, 100)
    table = [
        (a, op, b)
        for a in two_digit
        for b in two_digit
        for op in ("+", "-")
        if op == "+" or a >= b
    ]
    small = range(2, 13)
    table.extend((a, "*", b) for a in small for b in small)
    return table
def answer(a, op, b):
    """Return the result of ``a op b`` for op in {"+", "-", "*"}.

    Only the requested operation is evaluated, so operand types that support
    just that operator still work.

    Raises:
        KeyError: ``op`` is not one of "+", "-", "*".
    """
    if op == "+":
        return a + b
    if op == "-":
        return a - b
    if op == "*":
        return a * b
    raise KeyError(op)
def is_holdout(a, op, b, pct):
    """Return True if problem (a, op, b) belongs to the held-out split.

    The split is a deterministic MD5 hash of the problem text, so it is stable
    across runs and processes (unlike Python's salted ``hash``). ``pct`` is the
    target percentage to hold out: ``pct <= 0`` holds out nothing and
    ``pct >= 100`` holds out everything.

    The whole 128-bit digest is reduced mod 100. Reducing a single byte (0..255)
    mod 100 would over-weight residues 0..55 (3 preimages vs 2) and hold out
    ~9.4% instead of 8% at pct=8. NOTE: this selects a different split than the
    single-byte scheme, so do not resume a generalize run trained on that split.
    """
    if pct <= 0:
        return False
    digest = hashlib.md5(f"{a}{op}{b}".encode()).digest()
    return int.from_bytes(digest, "big") % 100 < pct
def render(a, op, b):
    """Format a problem as the question text the consultant reads, e.g. "23 + 54 =". """
    return "{} {} {} =".format(a, op, b)
| # ---- model loading (same pattern as moe_gradio.py) ------------------------------ | |
def load_specialist(domain, device):
    """Load one specialist model and its tokenizer from ``HERE/<domain>/``.

    Builds ``SpikeWhaleLM`` from ``specialist_config(domain)`` on ``device``, loads
    ``checkpoints/model.safetensors`` with every floating-point tensor upcast to
    float32, and loads ``tokenizer.json``. This function does not switch the
    model's train/eval mode.

    Returns:
        (model, tokenizer)
    """
    from safetensors.torch import load_file
    domain_dir = os.path.join(HERE, domain)
    weights_path = os.path.join(domain_dir, "checkpoints", "model.safetensors")
    model = SpikeWhaleLM(specialist_config(domain)).to(device)
    raw_state = load_file(weights_path, device=device)
    model.load_state_dict({
        name: tensor.float() if tensor.is_floating_point() else tensor
        for name, tensor in raw_state.items()
    })
    tokenizer = SpikeTokenizer(vocab_file=os.path.join(domain_dir, "tokenizer.json"))
    return model, tokenizer
| # ---- training ------------------------------------------------------------------- | |
def main():
    """Train the question->answer bridge and write it to OUT.

    Loads the frozen consultant (CONSULTANT) and the fine-tuned asker (ASKER),
    builds a fresh RecursiveLink, and trains link + asker with teacher forcing on
    the zero-padded answer digits. Every ``--eval-every`` steps (and at the final
    step) it runs an autoregressive eval, calls ``save()`` whenever the eval
    accuracy beats the best so far, and refreshes a resume checkpoint at
    ``OUT + ".last.pt"``.

    Raises:
        RuntimeError: a digit is not a single asker token.
        SystemExit: ``--batch`` exceeds the training pool, or last.pt was written
            with a different ``--holdout`` (pass ``--fresh`` to start over).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=4000)
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--link-lr", type=float, default=1e-3)
    ap.add_argument("--asker-lr", type=float, default=1e-4)
    ap.add_argument("--asker-wd", type=float, default=0.0)
    ap.add_argument("--holdout", type=int, default=0,
                    help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)")
    ap.add_argument("--eval-every", type=int, default=200)
    ap.add_argument("--eval-n", type=int, default=256)
    ap.add_argument("--eval-chunk", type=int, default=64)  # keep eval VRAM peaks small
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over")
    args = ap.parse_args()
    dev = args.device
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    print(f"[qa-link] device={dev}", flush=True)

    consultant, c_tok = load_specialist(CONSULTANT, dev)
    asker, a_tok = load_specialist(ASKER, dev)
    consultant.eval()
    for p in consultant.parameters():
        p.requires_grad_(False)  # frozen: only link + asker are optimized

    # Answer digits must be single tokens for the asker (position-aligned readout:
    # one generated token per digit).
    digit_ids = []
    for d in "0123456789":
        ids = a_tok.encode(d, add_special_tokens=False)
        if len(ids) != 1:  # explicit raise: an assert would vanish under `python -O`
            raise RuntimeError(f"digit {d!r} is not a single token: {ids}")
        digit_ids.append(ids[0])
    prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False)
    plen = len(prompt_ids)
    print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True)

    link = RecursiveLink(d_latent=D_LATENT).to(dev)
    opt = torch.optim.AdamW([
        {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0},
        {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd},
    ])

    probs = all_problems()
    train_pool = [p for p in probs if not is_holdout(*p, args.holdout)]
    eval_pool = [p for p in probs if is_holdout(*p, args.holdout)]
    memorize = args.holdout <= 0
    if memorize:
        eval_pool = train_pool  # no holdout: "accuracy" = coverage of the whole table
        print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True)
    else:
        print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True)
    # random.sample() in the training loop needs at least --batch problems.
    if args.batch > len(train_pool):
        raise SystemExit(f"[qa-link] --batch {args.batch} exceeds the {len(train_pool)} "
                         f"training problems (lower --batch or --holdout)")
    label = "accuracy" if memorize else "held-out exact"

    def encode_questions(batch):
        """Frozen consultant -> (len(batch), D_LATENT) latents.

        Questions are bucketed by token length and encoded one bucket at a time
        (the latent is a mean-pool over positions, so padding would corrupt it).
        """
        idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch]
        lat = torch.zeros(len(batch), D_LATENT, device=dev)
        by_len = {}
        for i, ids in enumerate(idss):
            by_len.setdefault(len(ids), []).append(i)
        # The consultant is frozen, so its forward never needs an autograd graph.
        with torch.no_grad():
            for idx in by_len.values():
                c_ids = torch.tensor([idss[i] for i in idx], device=dev)
                lat[idx] = consultant(input_ids=c_ids).latent
        return lat

    def ans_tokens(p):
        """Asker token ids of the zero-padded answer digits of problem ``p``."""
        return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"]

    def evaluate(pool, n, ablate=False):
        """Greedy-decode ANS_LEN digits for up to ``n`` problems sampled from ``pool``.

        Decoding is autoregressive (full-vocab argmax, no teacher forcing) and runs
        in chunks of ``--eval-chunk`` to keep VRAM peaks small. With ``ablate=True``
        the injected latent is replaced by zeros. Returns (exact-match rate,
        per-digit accuracy); both are 0.0 when ``pool`` is empty.
        """
        sample = random.sample(pool, min(n, len(pool)))
        if not sample:
            return 0.0, 0.0
        # Eval must neither build autograd graphs (VRAM) nor run in train mode;
        # restore both modules' modes afterwards so training continues unchanged.
        asker_was_training, link_was_training = asker.training, link.training
        asker.eval()
        link.eval()
        hit_e = hit_d = 0
        try:
            with torch.no_grad():
                for o in range(0, len(sample), args.eval_chunk):
                    chunk = sample[o:o + args.eval_chunk]
                    inj = link(encode_questions(chunk))
                    if ablate:
                        inj = torch.zeros_like(inj)  # cut the latent: the question is invisible
                    ids = torch.tensor([prompt_ids] * len(chunk), device=dev)
                    for _ in range(ANS_LEN):
                        logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
                        ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
                    pred = ids[:, plen:]
                    tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev)
                    hit_e += int((pred == tgt).all(dim=1).sum())
                    hit_d += int((pred == tgt).sum())
        finally:
            asker.train(asker_was_training)
            link.train(link_was_training)
        return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN)

    # Resume from last.pt if a previous run died mid-flight.
    last_pt = OUT + ".last.pt"
    best, start_step = -1.0, 0
    if os.path.exists(last_pt) and not args.fresh:
        # weights_only=False unpickles arbitrary objects: only resume from a
        # last.pt written by this script.
        st = torch.load(last_pt, map_location=dev, weights_only=False)
        # Mixing splits would leak held-out problems into training (or vice versa).
        # Older last.pt files carry no "holdout" key and are resumed as before.
        prev_holdout = st.get("holdout", args.holdout)
        if prev_holdout != args.holdout:
            raise SystemExit(f"[qa-link] {last_pt} was trained with --holdout {prev_holdout}, "
                             f"not {args.holdout}; pass --fresh to start over")
        link.load_state_dict(st["link"])
        asker.load_state_dict(st["asker"])
        opt.load_state_dict(st["opt"])
        best, start_step = st["best"], st["step"]
        print(f"[qa-link] resumed from step {start_step} (best {label} {best*100:.1f}%)", flush=True)

    t0 = time.time()
    asker.train()
    for step in range(start_step + 1, args.steps + 1):
        batch = random.sample(train_pool, args.batch)
        inj = link(encode_questions(batch))
        a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev)
        labels = a_ids.clone()
        labels[:, :plen] = -100  # loss only on the answer digits
        out = asker(input_ids=a_ids, labels=labels, inject_latent=inj)
        opt.zero_grad()
        out.loss.backward()
        opt.step()
        if step % args.eval_every == 0 or step == args.steps:
            ex, pd = evaluate(eval_pool, args.eval_n)
            extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%"
            print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} "
                  f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} "
                  f"[{time.time()-t0:.0f}s]", flush=True)
            if ex > best:
                best = ex
                save(link, asker, ex, step, args, memorize)
                print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True)
            # Resume checkpoint every eval, so a crash never loses more than eval_every
            # steps. "holdout" lets a later resume detect a split mismatch.
            torch.save({"link": link.state_dict(), "asker": asker.state_dict(),
                        "opt": opt.state_dict(), "best": best, "step": step,
                        "holdout": args.holdout}, last_pt + ".tmp")
            os.replace(last_pt + ".tmp", last_pt)

    # save() records only the best bridge's eval accuracy (holdout_exact metadata);
    # this ablation uses the LAST-step weights and is reported to the log only.
    ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True)
    print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True)
    print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True)
def save(link, asker, acc, step, args, memorize):
    """Write the bridge to OUT as a safetensors file via a temp file + rename.

    Tensor keys: ``link.*`` (the RecursiveLink), ``ali.*`` (the asker's
    ``model.latent_inject`` submodule) and ``asker.*`` (the full asker state).
    Floating-point tensors are stored as float16 on CPU; other tensors keep their
    dtype. ``acc`` lands in the ``holdout_exact`` metadata field (whole-table
    accuracy in memorize mode, held-out accuracy otherwise).
    """
    from safetensors.torch import save_file
    os.makedirs(os.path.dirname(OUT), exist_ok=True)

    def fp16_cpu(tensor):
        return tensor.detach().to("cpu", torch.float16).contiguous()

    tensors = {f"link.{name}": fp16_cpu(value) for name, value in link.state_dict().items()}
    for name, value in asker.model.latent_inject.state_dict().items():
        tensors[f"ali.{name}"] = fp16_cpu(value)
    for name, value in asker.state_dict().items():
        if value.is_floating_point():
            tensors[f"asker.{name}"] = fp16_cpu(value)
        else:
            tensors[f"asker.{name}"] = value.detach().cpu().contiguous()

    metadata = {
        "kind": "qa",
        "ans_len": str(ANS_LEN),
        "prompt": PROMPT,
        "asker": ASKER,
        "consultant": CONSULTANT,
        "mode": "memorize" if memorize else "generalize",
        "holdout_pct": str(args.holdout),
        "step": str(step),
        # accuracy over the whole table (memorize) or held-out set (generalize)
        "holdout_exact": f"{acc:.4f}",
        "ops": json.dumps(["+", "-", "*"]),
    }
    staging = OUT + ".tmp"
    save_file(tensors, staging, metadata=metadata)
    os.replace(staging, OUT)  # atomic: the panel hot-reloads this file while we train
if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())