""" train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's "Tell Math a secret" demo points at). Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced in training; decoded autoregressively at eval). 8% of (a, op, b) problems are HELD OUT of training, so eval accuracy on them is generalization, not memorization. Ablating the latent removes the question entirely -> accuracy collapses to the digit prior. Saves links/qa__language__from__reasoning.safetensors in the same key style as the key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load. Run: python agents/modmind/train_qa_link.py [--steps 4000] [--device cuda] """ from __future__ import annotations import argparse import hashlib import json import os import random import sys import time import torch import torch.nn.functional as F HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, HERE) from model import RecursiveLink, SpikeWhaleLM # noqa: E402 from specialist_presets import specialist_config # noqa: E402 from spike_tokenizer import SpikeTokenizer # noqa: E402 ASKER, CONSULTANT = "language", "reasoning" D_LATENT = 256 PROMPT = "ANS> " ANS_LEN = 3 # answers zero-padded to 3 digits ("077") HOLDOUT_PCT = 8 # % of problems held out of training entirely OUT = os.path.join(HERE, "links", f"qa__{ASKER}__from__{CONSULTANT}.safetensors") # ---- the problem space -------------------------------------------------------- def all_problems(): """Every (a, op, b) the bridge is trained/evaluated on. Answers are 0..198.""" probs = [] for a in range(10, 100): for b in range(10, 100): probs.append((a, "+", b)) if a >= b: probs.append((a, "-", b)) for a in range(2, 13): for b in range(2, 13): probs.append((a, "*", b)) return probs def answer(a, op, b): return {"+": a + b, "-": a - b, "*": a * b}[op] def is_holdout(a, op, b, pct): if pct <= 0: return False h = hashlib.md5(f"{a}{op}{b}".encode()).digest()[0] return h % 100 < pct def render(a, op, b): return f"{a} {op} {b} =" # ---- model loading (same pattern as moe_gradio.py) ------------------------------ def load_specialist(domain, device): from safetensors.torch import load_file ck = os.path.join(HERE, domain, "checkpoints", "model.safetensors") cfg = specialist_config(domain) m = SpikeWhaleLM(cfg).to(device) sd = load_file(ck, device=device) sd = {k: (v.float() if v.is_floating_point() else v) for k, v in sd.items()} m.load_state_dict(sd) tok = SpikeTokenizer(vocab_file=os.path.join(HERE, domain, "tokenizer.json")) return m, tok # ---- training ------------------------------------------------------------------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--steps", type=int, default=4000) ap.add_argument("--batch", type=int, default=128) ap.add_argument("--link-lr", type=float, default=1e-3) ap.add_argument("--asker-lr", type=float, default=1e-4) ap.add_argument("--asker-wd", type=float, default=0.0) ap.add_argument("--holdout", type=int, default=0, help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)") ap.add_argument("--eval-every", type=int, default=200) ap.add_argument("--eval-n", type=int, default=256) ap.add_argument("--eval-chunk", type=int, default=64) # keep eval VRAM peaks small ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") ap.add_argument("--seed", type=int, default=0) ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over") args = ap.parse_args() dev = args.device random.seed(args.seed); torch.manual_seed(args.seed) print(f"[qa-link] device={dev}", flush=True) consultant, c_tok = load_specialist(CONSULTANT, dev) asker, a_tok = load_specialist(ASKER, dev) consultant.eval() for p in consultant.parameters(): p.requires_grad_(False) # answer digits must be single tokens for the asker (position-aligned readout) digit_ids = [] for d in "0123456789": ids = a_tok.encode(d, add_special_tokens=False) assert len(ids) == 1, f"digit {d!r} is not a single token: {ids}" digit_ids.append(ids[0]) prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False) plen = len(prompt_ids) print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True) link = RecursiveLink(d_latent=D_LATENT).to(dev) opt = torch.optim.AdamW([ {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0}, {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd}, ]) probs = all_problems() train_pool = [p for p in probs if not is_holdout(*p, args.holdout)] eval_pool = [p for p in probs if is_holdout(*p, args.holdout)] memorize = args.holdout <= 0 if memorize: eval_pool = train_pool # no holdout: "accuracy" = coverage of the whole table print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True) else: print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True) label = "accuracy" if memorize else "held-out exact" @torch.no_grad() def encode_questions(batch): """Frozen consultant -> latents. Bucketed by token length (latent is a mean-pool over positions, so padding would corrupt it).""" idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch] lat = torch.zeros(len(batch), D_LATENT, device=dev) by_len = {} for i, ids in enumerate(idss): by_len.setdefault(len(ids), []).append(i) for L, idx in by_len.items(): c_ids = torch.tensor([idss[i] for i in idx], device=dev) lat[idx] = consultant(input_ids=c_ids).latent return lat def ans_tokens(p): return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"] @torch.no_grad() def evaluate(pool, n, ablate=False): """Autoregressive 3-digit decode (full-vocab argmax, no teacher forcing). Chunked to keep VRAM peaks small.""" asker.eval() sample = random.sample(pool, min(n, len(pool))) hit_e = hit_d = 0 for o in range(0, len(sample), args.eval_chunk): chunk = sample[o:o + args.eval_chunk] lat = encode_questions(chunk) inj = torch.zeros_like(link(lat)) if ablate else link(lat) ids = torch.tensor([prompt_ids] * len(chunk), device=dev) for _ in range(ANS_LEN): logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :] ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1) pred = ids[:, plen:] tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev) hit_e += int((pred == tgt).all(dim=1).sum()) hit_d += int((pred == tgt).sum()) asker.train() return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN) # resume from last.pt if a previous run died mid-flight last_pt = OUT + ".last.pt" best, start_step = -1.0, 0 if os.path.exists(last_pt) and not args.fresh: st = torch.load(last_pt, map_location=dev, weights_only=False) link.load_state_dict(st["link"]); asker.load_state_dict(st["asker"]) opt.load_state_dict(st["opt"]); best, start_step = st["best"], st["step"] print(f"[qa-link] resumed from step {start_step} (best held-out {best*100:.1f}%)", flush=True) t0 = time.time() asker.train() for step in range(start_step + 1, args.steps + 1): batch = random.sample(train_pool, args.batch) lat = encode_questions(batch) inj = link(lat) a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev) labels = a_ids.clone() labels[:, :plen] = -100 # loss only on the answer digits out = asker(input_ids=a_ids, labels=labels, inject_latent=inj) opt.zero_grad(); out.loss.backward(); opt.step() if step % args.eval_every == 0 or step == args.steps: ex, pd = evaluate(eval_pool, args.eval_n) extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%" print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} " f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} " f"[{time.time()-t0:.0f}s]", flush=True) if ex > best: best = ex save(link, asker, ex, step, args, memorize) print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True) # resume checkpoint every eval, so a crash never loses more than eval_every steps torch.save({"link": link.state_dict(), "asker": asker.state_dict(), "opt": opt.state_dict(), "best": best, "step": step}, last_pt + ".tmp") os.replace(last_pt + ".tmp", last_pt) # final ablation numbers from the BEST saved bridge are written at save(); # report the last-step ablation here for the log. ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True) print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True) print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True) def save(link, asker, acc, step, args, memorize): from safetensors.torch import save_file os.makedirs(os.path.dirname(OUT), exist_ok=True) t = {} for k, v in link.state_dict().items(): t["link." + k] = v.detach().to("cpu", torch.float16).contiguous() for k, v in asker.model.latent_inject.state_dict().items(): t["ali." + k] = v.detach().to("cpu", torch.float16).contiguous() for k, v in asker.state_dict().items(): t["asker." + k] = (v.detach().to("cpu", torch.float16).contiguous() if v.is_floating_point() else v.detach().cpu().contiguous()) tmp = OUT + ".tmp" save_file(t, tmp, metadata={ "kind": "qa", "ans_len": str(ANS_LEN), "prompt": PROMPT, "asker": ASKER, "consultant": CONSULTANT, "mode": "memorize" if memorize else "generalize", "holdout_pct": str(args.holdout), "step": str(step), # accuracy over the whole table (memorize) or held-out set (generalize) "holdout_exact": f"{acc:.4f}", "ops": json.dumps(["+", "-", "*"]), }) os.replace(tmp, OUT) # atomic: the panel hot-reloads this file while we train if __name__ == "__main__": main()