ModuleMind / agents /modmind /train_qa_link.py
Quazim0t0's picture
Upload 7 files
73dd4cf verified
Raw
History Blame Contribute Delete
11.2 kB
"""
train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's
"Tell Math a secret" demo points at).
Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning
specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a
fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading
nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced
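in training; decoded autoregressively at eval). By default (--holdout 0) the bridge
is trained on EVERY (a, op, b) problem and eval reports coverage of the whole table
(memorize mode). With --holdout P (e.g. 8), roughly P% of problems are HELD OUT of
training, so eval accuracy on them is generalization, not memorization. Ablating
the latent removes the question entirely -> accuracy collapses to the digit prior.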
Saves links/qa__language__from__reasoning.safetensors in the same key style as the
key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load.
Run: python agents/modmind/train_qa_link.py [--steps 4000] [--device cuda]
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import random
import sys
import time
import torch
import torch.nn.functional as F
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)
from model import RecursiveLink, SpikeWhaleLM # noqa: E402
from specialist_presets import specialist_config # noqa: E402
from spike_tokenizer import SpikeTokenizer # noqa: E402
# Bridge direction: the "language" specialist asks, the "reasoning" specialist
# is consulted -- its latent is the asker's only view of the question.
ASKER = "language"
CONSULTANT = "reasoning"
D_LATENT = 256  # width of the consultant latent fed to RecursiveLink(d_latent=...)
PROMPT = "ANS> "  # asker-side prefix; the answer digits follow it
ANS_LEN = 3  # answers zero-padded to this many digits ("077")
# NOTE(review): not referenced anywhere else in this file; the --holdout flag
# defaults to 0 (train on every problem). Pass --holdout 8 explicitly to run
# the held-out generalization experiment this value describes.
HOLDOUT_PCT = 8
OUT = os.path.join(
    HERE, "links", "qa__{}__from__{}.safetensors".format(ASKER, CONSULTANT)
)
# ---- the problem space --------------------------------------------------------
def all_problems():
"""Every (a, op, b) the bridge is trained/evaluated on. Answers are 0..198."""
probs = []
for a in range(10, 100):
for b in range(10, 100):
probs.append((a, "+", b))
if a >= b:
probs.append((a, "-", b))
for a in range(2, 13):
for b in range(2, 13):
probs.append((a, "*", b))
return probs
def answer(a, op, b):
return {"+": a + b, "-": a - b, "*": a * b}[op]
def is_holdout(a, op, b, pct):
    """Deterministically decide whether problem (a, op, b) is held out.

    Hashes the problem text with MD5 into one of 100 near-uniform buckets and
    holds it out when the bucket is < pct, so about pct% of problems are held
    out. The split is stable across runs and machines. pct <= 0 holds out
    nothing; pct >= 100 holds out everything.

    NOTE: this bucket differs from the earlier single-byte version, so a
    .last.pt produced with a nonzero --holdout under the old split should be
    discarded (--fresh) rather than resumed.
    """
    if pct <= 0:
        return False
    digest = hashlib.md5(f"{a}{op}{b}".encode()).digest()
    # Use 64 bits, not one byte: a byte (0..255) % 100 gives residues 0..55
    # three preimages vs two for 56..99, so pct=8 held out ~9.4%, not 8%.
    bucket = int.from_bytes(digest[:8], "big") % 100
    return bucket < pct
def render(a, op, b):
    """Format a problem as the question text the consultant sees, e.g. "23 + 54 ="."""
    return "{} {} {} =".format(a, op, b)
# ---- model loading (same pattern as moe_gradio.py) ------------------------------
def load_specialist(domain, device):
from safetensors.torch import load_file
ck = os.path.join(HERE, domain, "checkpoints", "model.safetensors")
cfg = specialist_config(domain)
m = SpikeWhaleLM(cfg).to(device)
sd = load_file(ck, device=device)
sd = {k: (v.float() if v.is_floating_point() else v) for k, v in sd.items()}
m.load_state_dict(sd)
tok = SpikeTokenizer(vocab_file=os.path.join(HERE, domain, "tokenizer.json"))
return m, tok
# ---- training -------------------------------------------------------------------
def main():
    """Train the question->answer bridge and save the best one to OUT.

    Loads the frozen consultant (reasoning) and the trainable asker (language)
    and builds a fresh RecursiveLink. It then trains link + asker so that,
    reading only the consultant's latent for "a op b =", the asker emits the
    zero-padded answer digits after PROMPT. Every --eval-every steps it
    greedily decodes a sample and saves the best-scoring bridge via save().
    At the same point it also writes a resumable checkpoint to OUT + ".last.pt".
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=4000)
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--link-lr", type=float, default=1e-3)
    ap.add_argument("--asker-lr", type=float, default=1e-4)
    ap.add_argument("--asker-wd", type=float, default=0.0)
    ap.add_argument("--holdout", type=int, default=0,
                    help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)")
    ap.add_argument("--eval-every", type=int, default=200)
    ap.add_argument("--eval-n", type=int, default=256)
    ap.add_argument("--eval-chunk", type=int, default=64)  # keep eval VRAM peaks small
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over")
    args = ap.parse_args()
    dev = args.device
    random.seed(args.seed); torch.manual_seed(args.seed)
    # Evaluation samples from its own RNG so --eval-every / --eval-n do not
    # change which training batches get drawn from the global RNG.
    eval_rng = random.Random(args.seed + 1)
    print(f"[qa-link] device={dev}", flush=True)

    # Problem split first: it is cheap, uses no RNG, and lets bad flags fail
    # before the specialists are loaded.
    probs = all_problems()
    train_pool = [p for p in probs if not is_holdout(*p, args.holdout)]
    eval_pool = [p for p in probs if is_holdout(*p, args.holdout)]
    memorize = args.holdout <= 0
    if memorize:
        eval_pool = train_pool  # no holdout: "accuracy" = coverage of the whole table
        print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True)
    else:
        print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True)
    if not 1 <= args.batch <= len(train_pool):
        ap.error(f"--batch must be between 1 and the {len(train_pool)} training problems "
                 f"(got {args.batch}; lower --batch or --holdout)")
    label = "accuracy" if memorize else "held-out exact"

    consultant, c_tok = load_specialist(CONSULTANT, dev)
    asker, a_tok = load_specialist(ASKER, dev)
    consultant.eval()
    for p in consultant.parameters():
        p.requires_grad_(False)
    # answer digits must be single tokens for the asker (position-aligned readout)
    digit_ids = []
    for d in "0123456789":
        ids = a_tok.encode(d, add_special_tokens=False)
        assert len(ids) == 1, f"digit {d!r} is not a single token: {ids}"
        digit_ids.append(ids[0])
    prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False)
    plen = len(prompt_ids)
    print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True)
    link = RecursiveLink(d_latent=D_LATENT).to(dev)
    opt = torch.optim.AdamW([
        {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0},
        {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd},
    ])

    @torch.no_grad()
    def encode_questions(batch):
        """Frozen consultant -> latents (one D_LATENT row per problem).

        Questions are bucketed by token length and each bucket is run
        unpadded, because the latent is a mean-pool over positions and
        padding would corrupt it.
        """
        idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch]
        lat = torch.zeros(len(batch), D_LATENT, device=dev)
        by_len = {}
        for i, ids in enumerate(idss):
            by_len.setdefault(len(ids), []).append(i)
        for idx in by_len.values():
            c_ids = torch.tensor([idss[i] for i in idx], device=dev)
            lat[idx] = consultant(input_ids=c_ids).latent
        return lat

    def ans_tokens(p):
        """Asker token ids for the zero-padded answer of problem p."""
        return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"]

    @torch.no_grad()
    def evaluate(pool, n, ablate=False):
        """Greedy ANS_LEN-digit decode (full-vocab argmax, no teacher forcing).

        Scores up to n problems sampled from pool, in chunks of --eval-chunk
        to keep VRAM peaks small. With ablate=True the injected latent is
        zeroed. Returns (exact-match rate, per-digit rate); (0.0, 0.0) for an
        empty pool.
        """
        sample = eval_rng.sample(pool, min(n, len(pool)))
        if not sample:
            return 0.0, 0.0
        # Both trainable modules go to eval mode (the link too, in case it
        # has dropout/norm layers) and are restored afterwards.
        asker.eval()
        link.eval()
        hit_e = hit_d = 0
        try:
            for o in range(0, len(sample), args.eval_chunk):
                chunk = sample[o:o + args.eval_chunk]
                inj = link(encode_questions(chunk))
                if ablate:
                    inj = torch.zeros_like(inj)
                ids = torch.tensor([prompt_ids] * len(chunk), device=dev)
                for _ in range(ANS_LEN):
                    logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
                    ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
                pred = ids[:, plen:]
                tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev)
                hit_e += int((pred == tgt).all(dim=1).sum())
                hit_d += int((pred == tgt).sum())
        finally:
            asker.train()
            link.train()
        return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN)

    # resume from last.pt if a previous run died mid-flight
    last_pt = OUT + ".last.pt"
    best, start_step = -1.0, 0
    if os.path.exists(last_pt) and not args.fresh:
        # weights_only=False unpickles arbitrary objects: only resume from a
        # checkpoint this script wrote itself.
        st = torch.load(last_pt, map_location=dev, weights_only=False)
        # Older checkpoints carry no "holdout" key; they are accepted as-is.
        prev_holdout = st.get("holdout", args.holdout)
        if prev_holdout != args.holdout:
            raise SystemExit(
                f"[qa-link] {last_pt} was trained with --holdout {prev_holdout}, not "
                f"{args.holdout}; its train/eval split does not match. Re-run with "
                f"--fresh or --holdout {prev_holdout}.")
        link.load_state_dict(st["link"]); asker.load_state_dict(st["asker"])
        opt.load_state_dict(st["opt"]); best, start_step = st["best"], st["step"]
        print(f"[qa-link] resumed from step {start_step} (best {label} {best*100:.1f}%)", flush=True)

    t0 = time.time()
    asker.train()
    link.train()
    for step in range(start_step + 1, args.steps + 1):
        batch = random.sample(train_pool, args.batch)
        inj = link(encode_questions(batch))
        a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev)
        labels = a_ids.clone()
        labels[:, :plen] = -100  # loss only on the answer digits
        out = asker(input_ids=a_ids, labels=labels, inject_latent=inj)
        opt.zero_grad(); out.loss.backward(); opt.step()
        if step % args.eval_every == 0 or step == args.steps:
            ex, pd = evaluate(eval_pool, args.eval_n)
            extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%"
            print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} "
                  f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} "
                  f"[{time.time()-t0:.0f}s]", flush=True)
            if ex > best:
                best = ex
                save(link, asker, ex, step, args, memorize)
                print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True)
            # resume checkpoint every eval, so a crash never loses more than eval_every
            # steps; the holdout is recorded so a resume cannot silently mix splits.
            torch.save({"link": link.state_dict(), "asker": asker.state_dict(),
                        "opt": opt.state_dict(), "best": best, "step": step,
                        "holdout": args.holdout}, last_pt + ".tmp")
            os.replace(last_pt + ".tmp", last_pt)

    # Ablation check on the LAST-step weights (not necessarily the best bridge
    # saved to OUT; save() records no ablation numbers): zero the injected
    # latent and confirm accuracy collapses toward the digit prior.
    ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True)
    print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True)
    print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True)
def save(link, asker, acc, step, args, memorize):
    """Atomically write the bridge to OUT as a safetensors file.

    Key layout:
      link.*   every RecursiveLink tensor, cast to float16
      ali.*    the asker's model.latent_inject tensors, cast to float16
      asker.*  the full asker state; floating tensors cast to float16,
               other dtypes kept as-is
    Metadata records the task, mode, holdout %, step and the score ``acc``
    (whole-table accuracy in memorize mode, held-out exact otherwise).
    The file is written to OUT + ".tmp" and then renamed over OUT, because
    the panel hot-reloads OUT while training runs.
    """
    from safetensors.torch import save_file

    def to_cpu(tensor, half):
        # Detach, move to CPU (casting to float16 when requested), make contiguous.
        tensor = tensor.detach()
        moved = tensor.to("cpu", torch.float16) if half else tensor.cpu()
        return moved.contiguous()

    os.makedirs(os.path.dirname(OUT), exist_ok=True)
    tensors = {f"link.{name}": to_cpu(value, True)
               for name, value in link.state_dict().items()}
    tensors.update(
        (f"ali.{name}", to_cpu(value, True))
        for name, value in asker.model.latent_inject.state_dict().items()
    )
    tensors.update(
        (f"asker.{name}", to_cpu(value, value.is_floating_point()))
        for name, value in asker.state_dict().items()
    )
    metadata = {
        "kind": "qa", "ans_len": str(ANS_LEN), "prompt": PROMPT,
        "asker": ASKER, "consultant": CONSULTANT,
        "mode": "memorize" if memorize else "generalize",
        "holdout_pct": str(args.holdout), "step": str(step),
        # accuracy over the whole table (memorize) or held-out set (generalize)
        "holdout_exact": f"{acc:.4f}",
        "ops": json.dumps(["+", "-", "*"]),
    }
    staging = OUT + ".tmp"
    save_file(tensors, staging, metadata=metadata)
    os.replace(staging, OUT)  # atomic swap for the hot-reloading panel
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()