Spaces:
Running on Zero
Running on Zero
File size: 11,157 Bytes
73dd4cf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 | """
train_qa_link.py -- train a question->answer latent bridge (the upgrade the panel's
"Tell Math a secret" demo points at).
Task: an arithmetic question ("23 + 54 =") is shown ONLY to the frozen Math/reasoning
specialist, which encodes it to its 256-d output latent. A NEW RecursiveLink + a
fine-tuned Language asker must emit the ANSWER ("077", zero-padded digits) reading
nothing but that latent: asker input is just "ANS> " + answer digits (teacher-forced
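in training; decoded autoregressively at eval). With --holdout P, a hash-selected ~P%
of (a, op, b) problems are HELD OUT of training, so eval accuracy on them measures
generalization, not memorization; the default (--holdout 0) trains on every problem
and reports coverage of the whole table (memorize / lookup-table mode). Ablating
the latent removes the question entirely -> accuracy collapses to the digit prior.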
Saves links/qa__language__from__reasoning.safetensors in the same key style as the
key-recall bridge (link./ali./asker. + metadata) for moe_gradio.py to load.
Run: python agents/modmind/train_qa_link.py [--steps 4000] [--holdout 8] [--device cuda]
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import random
import sys
import time
import torch
import torch.nn.functional as F
HERE = os.path.dirname(os.path.abspath(__file__))
# Put this file's directory first on sys.path so the sibling modules imported
# just below (model, specialist_presets, spike_tokenizer) resolve. These imports
# must stay AFTER this insert, hence the E402 suppressions.
sys.path.insert(0, HERE)
from model import RecursiveLink, SpikeWhaleLM  # noqa: E402
from specialist_presets import specialist_config  # noqa: E402
from spike_tokenizer import SpikeTokenizer  # noqa: E402
# Bridge direction: the ASKER specialist (fine-tuned) decodes answers from the
# frozen CONSULTANT's latent. Both names are also checkpoint/tokenizer subdirs.
ASKER, CONSULTANT = "language", "reasoning"
D_LATENT = 256  # consultant latent width; also the RecursiveLink input width
PROMPT = "ANS> "  # the only text the asker sees before the answer digits
ANS_LEN = 3  # answers zero-padded to 3 digits ("077")
# NOTE(review): HOLDOUT_PCT is not read anywhere in this file -- the split is
# controlled by --holdout, which defaults to 0 (train on everything). Presumably
# kept for external importers; verify before relying on or removing it.
HOLDOUT_PCT = 8  # % of problems held out of training entirely
# Destination of the best bridge (written by save(), loaded by the panel).
OUT = os.path.join(HERE, "links", f"qa__{ASKER}__from__{CONSULTANT}.safetensors")
# ---- the problem space --------------------------------------------------------
def all_problems():
    """Return the full (a, op, b) problem table in a fixed, deterministic order.

    - "+" and "-" over two-digit operands 10..99. "-" is only included when
      a >= b, so no answer is negative. For each (a, b), "+" comes before "-".
    - "*" over the 2..12 x 2..12 times table.

    Answers therefore span 0..198, i.e. at most three digits.
    """
    two_digit = range(10, 100)
    add_sub = [
        (a, op, b)
        for a in two_digit
        for b in two_digit
        for op in ("+", "-")
        if op == "+" or a >= b
    ]
    table = range(2, 13)
    mul = [(a, "*", b) for a in table for b in table]
    return add_sub + mul
def answer(a, op, b):
    """Return the result of ``a op b`` for op in {"+", "-", "*"}.

    All three results are computed before the lookup. Any other operator
    raises KeyError.
    """
    results = dict(zip(("+", "-", "*"), (a + b, a - b, a * b)))
    return results[op]
def is_holdout(a, op, b, pct):
    """Return True if problem (a, op, b) belongs to the held-out split.

    Membership is a deterministic function of the key "{a}{op}{b}" (MD5), so
    the split is stable across runs, seeds and machines. pct <= 0 disables the
    holdout entirely (always False); pct >= 100 holds out everything.

    The whole 128-bit digest is reduced mod 100, which keeps the 100 buckets
    uniform, so about pct% of problems are held out. Reducing a single digest
    byte (0..255) mod 100 instead would make buckets 0..55 1.5x as likely as
    56..99, and would hold out ~3*pct/256 of the problems (~9.4% for pct=8).

    NOTE: this split differs from one computed from the first digest byte
    only. Do not resume a --holdout run (last.pt) that was started under that
    older rule, or previously trained problems will leak into the eval set.
    """
    if pct <= 0:
        return False
    digest = hashlib.md5(f"{a}{op}{b}".encode()).digest()
    return int.from_bytes(digest, "big") % 100 < pct
def render(a, op, b):
    """Format a problem as the question text shown to the consultant, e.g. "23 + 54 ="."""
    return "{} {} {} =".format(a, op, b)
# ---- model loading (same pattern as moe_gradio.py) ------------------------------
def load_specialist(domain, device):
    """Build the ``domain`` specialist and its tokenizer.

    The model comes from the preset config plus
    <HERE>/<domain>/checkpoints/model.safetensors. Floating-point tensors are
    upcast to float32 before a (strict, default) load_state_dict. The
    tokenizer is read from <HERE>/<domain>/tokenizer.json.

    Returns:
        (model, tokenizer), with the model placed on ``device``.
    """
    from safetensors.torch import load_file

    domain_dir = os.path.join(HERE, domain)
    weights_path = os.path.join(domain_dir, "checkpoints", "model.safetensors")
    model = SpikeWhaleLM(specialist_config(domain)).to(device)
    raw = load_file(weights_path, device=device)
    model.load_state_dict(
        {name: (t.float() if t.is_floating_point() else t) for name, t in raw.items()}
    )
    tokenizer = SpikeTokenizer(vocab_file=os.path.join(domain_dir, "tokenizer.json"))
    return model, tokenizer
# ---- training -------------------------------------------------------------------
def main():
    """Train the question->answer bridge and save the best bridge to OUT.

    The frozen consultant encodes each rendered question into a D_LATENT
    latent. A fresh RecursiveLink maps that latent to an injection, which the
    (fine-tuned) asker reads while predicting the ANS_LEN answer digits after
    PROMPT. Loss covers only the answer digits.

    Every --eval-every steps (and at the last step), the bridge is scored by
    greedy decoding:
    - a new best score is written via save();
    - a resume checkpoint (OUT + ".last.pt") is refreshed, so a crash loses
      at most --eval-every steps.

    Invalid CLI values (non-positive batch/eval sizes, a --holdout that leaves
    an empty pool) are rejected via argparse's error exit.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--steps", type=int, default=4000)
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--link-lr", type=float, default=1e-3)
    ap.add_argument("--asker-lr", type=float, default=1e-4)
    ap.add_argument("--asker-wd", type=float, default=0.0)
    ap.add_argument("--holdout", type=int, default=0,
                    help="%% of problems held out of training (0 = train on ALL, the lookup-table demo)")
    ap.add_argument("--eval-every", type=int, default=200)
    ap.add_argument("--eval-n", type=int, default=256)
    ap.add_argument("--eval-chunk", type=int, default=64)  # keep eval VRAM peaks small
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--fresh", action="store_true", help="ignore last.pt and start over")
    args = ap.parse_args()
    # Reject values that would otherwise crash mid-run:
    # - eval_every: modulo by zero;
    # - eval_chunk: range() step of zero;
    # - eval_n: division by an empty eval sample;
    # - batch: empty training batch.
    for flag in ("batch", "eval_every", "eval_n", "eval_chunk"):
        if getattr(args, flag) < 1:
            ap.error(f"--{flag.replace('_', '-')} must be >= 1")
    dev = args.device
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    print(f"[qa-link] device={dev}", flush=True)

    consultant, c_tok = load_specialist(CONSULTANT, dev)
    asker, a_tok = load_specialist(ASKER, dev)
    # The consultant is a fixed encoder: eval mode, no gradients.
    consultant.eval()
    for p in consultant.parameters():
        p.requires_grad_(False)

    # Answer digits must be single asker tokens: predictions are read back
    # position-by-position right after the prompt.
    digit_ids = []
    for d in "0123456789":
        ids = a_tok.encode(d, add_special_tokens=False)
        assert len(ids) == 1, f"digit {d!r} is not a single token: {ids}"
        digit_ids.append(ids[0])
    prompt_ids = a_tok.encode(PROMPT, add_special_tokens=False)
    plen = len(prompt_ids)
    print(f"[qa-link] prompt {PROMPT!r} = {plen} tokens; digits map to single tokens", flush=True)

    link = RecursiveLink(d_latent=D_LATENT).to(dev)
    # Both the new link and the whole asker are trained, with separate LRs.
    opt = torch.optim.AdamW([
        {"params": list(link.parameters()), "lr": args.link_lr, "weight_decay": 0.0},
        {"params": list(asker.parameters()), "lr": args.asker_lr, "weight_decay": args.asker_wd},
    ])

    probs = all_problems()
    train_pool = [p for p in probs if not is_holdout(*p, args.holdout)]
    eval_pool = [p for p in probs if is_holdout(*p, args.holdout)]
    memorize = args.holdout <= 0
    if memorize:
        eval_pool = train_pool  # no holdout: "accuracy" = coverage of the whole table
        print(f"[qa-link] MEMORIZE mode: training on ALL {len(train_pool)} problems (no holdout)", flush=True)
    else:
        print(f"[qa-link] {len(train_pool)} train problems, {len(eval_pool)} held out", flush=True)
    if not train_pool:
        ap.error(f"--holdout {args.holdout} leaves no training problems")
    if not eval_pool:
        ap.error(f"--holdout {args.holdout} selects no held-out problems")
    label = "accuracy" if memorize else "held-out exact"
    # Every step samples without replacement; never ask for more than exists.
    batch_size = min(args.batch, len(train_pool))

    @torch.no_grad()
    def encode_questions(batch):
        """Frozen consultant -> (len(batch), D_LATENT) latents.

        Questions are bucketed by token length, so each forward pass is
        unpadded. Per the original design note, the latent is a mean-pool
        over positions, so padding would corrupt it.
        """
        idss = [c_tok.encode(render(*p), add_special_tokens=False) for p in batch]
        lat = torch.zeros(len(batch), D_LATENT, device=dev)
        by_len = {}
        for i, ids in enumerate(idss):
            by_len.setdefault(len(ids), []).append(i)
        for idx in by_len.values():
            c_ids = torch.tensor([idss[i] for i in idx], device=dev)
            lat[idx] = consultant(input_ids=c_ids).latent
        return lat

    def ans_tokens(p):
        """Asker token ids of p's answer, zero-padded to ANS_LEN digits."""
        return [digit_ids[int(ch)] for ch in f"{answer(*p):0{ANS_LEN}d}"]

    @torch.no_grad()
    def evaluate(pool, n, ablate=False):
        """Greedy-decode ANS_LEN digits (full-vocab argmax, no teacher forcing).

        Scores a random sample of min(n, len(pool)) problems, in chunks of
        --eval-chunk to keep VRAM peaks small. With ablate=True, the link
        output is replaced by zeros (the question is cut). The asker is put
        back into train mode before returning.

        Returns:
            (exact-match rate, per-digit accuracy).
        """
        asker.eval()
        sample = random.sample(pool, min(n, len(pool)))
        hit_e = hit_d = 0
        for o in range(0, len(sample), args.eval_chunk):
            chunk = sample[o:o + args.eval_chunk]
            lat = encode_questions(chunk)
            inj = torch.zeros_like(link(lat)) if ablate else link(lat)
            ids = torch.tensor([prompt_ids] * len(chunk), device=dev)
            for _ in range(ANS_LEN):
                logits = asker(input_ids=ids, inject_latent=inj).logits[:, -1, :]
                ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=1)
            pred = ids[:, plen:]
            tgt = torch.tensor([ans_tokens(p) for p in chunk], device=dev)
            hit_e += int((pred == tgt).all(dim=1).sum())
            hit_d += int((pred == tgt).sum())
        asker.train()
        return hit_e / len(sample), hit_d / (len(sample) * ANS_LEN)

    # Resume from last.pt if a previous run died mid-flight.
    # NOTE(review): the checkpoint does not record --holdout/--seed; resuming
    # with different values silently mixes splits.
    last_pt = OUT + ".last.pt"
    best, start_step = -1.0, 0
    if os.path.exists(last_pt) and not args.fresh:
        # Full unpickling (weights_only=False): only load files this script wrote.
        st = torch.load(last_pt, map_location=dev, weights_only=False)
        link.load_state_dict(st["link"])
        asker.load_state_dict(st["asker"])
        opt.load_state_dict(st["opt"])
        best, start_step = st["best"], st["step"]
        # Continue the batch stream where it stopped instead of replaying it
        # from the seed. Older last.pt files predate this key.
        if "py_rng" in st:
            random.setstate(st["py_rng"])
        print(f"[qa-link] resumed from step {start_step} (best {label} {best*100:.1f}%)", flush=True)

    t0 = time.time()
    asker.train()
    for step in range(start_step + 1, args.steps + 1):
        batch = random.sample(train_pool, batch_size)
        lat = encode_questions(batch)
        inj = link(lat)
        a_ids = torch.tensor([prompt_ids + ans_tokens(p) for p in batch], device=dev)
        labels = a_ids.clone()
        labels[:, :plen] = -100  # loss only on the answer digits
        out = asker(input_ids=a_ids, labels=labels, inject_latent=inj)
        opt.zero_grad()
        out.loss.backward()
        opt.step()

        if step % args.eval_every == 0 or step == args.steps:
            ex, pd = evaluate(eval_pool, args.eval_n)
            extra = "" if memorize else f" train exact {evaluate(train_pool, args.eval_n)[0]*100:5.1f}%"
            print(f"[qa-link] step {step:5d} loss {out.loss.item():.4f} "
                  f"{label} {ex*100:5.1f}% (digits {pd*100:5.1f}%){extra} "
                  f"[{time.time()-t0:.0f}s]", flush=True)
            # NOTE(review): "best" is the max over fresh random eval samples,
            # and in generalize mode it is selected on the held-out set itself,
            # so the reported best is optimistically biased.
            if ex > best:
                best = ex
                save(link, asker, ex, step, args, memorize)
                print(f"[qa-link] saved -> {OUT} ({label} {ex*100:.1f}%)", flush=True)
            # Resume checkpoint every eval, so a crash never loses more than
            # eval_every steps. Written to .tmp then renamed (atomic replace).
            torch.save({"link": link.state_dict(), "asker": asker.state_dict(),
                        "opt": opt.state_dict(), "best": best, "step": step,
                        "py_rng": random.getstate()}, last_pt + ".tmp")
            os.replace(last_pt + ".tmp", last_pt)

    # Ablation sanity check on the CURRENT (last-step) in-memory weights. This
    # is not the best saved bridge, and the numbers are only logged, not
    # written into OUT's metadata. Zeroing the injection removes the question,
    # so accuracy should collapse toward the digit prior.
    ex_a, pd_a = evaluate(eval_pool, args.eval_n, ablate=True)
    print(f"[qa-link] ablated (latent cut): exact {ex_a*100:.1f}% / digits {pd_a*100:.1f}%", flush=True)
    print(f"[qa-link] done. best {label} {best*100:.1f}%", flush=True)
def save(link, asker, acc, step, args, memorize):
    """Write the bridge to OUT as a safetensors file, replacing it atomically.

    Key layout (prefix -> source):
    - "link."  the RecursiveLink state dict, always cast to float16;
    - "ali."   asker.model.latent_inject state dict, always cast to float16;
    - "asker." the full asker state dict. Floating tensors are cast to
      float16; other dtypes are kept as-is on CPU.

    Metadata records the mode, holdout %, step and ``acc``. ``acc`` is stored
    under "holdout_exact" in BOTH modes: it is whole-table accuracy when
    memorizing, held-out accuracy otherwise.
    """
    from safetensors.torch import save_file

    def _half(t):
        # Detached, CPU, float16, contiguous copy for serialization.
        return t.detach().to("cpu", torch.float16).contiguous()

    os.makedirs(os.path.dirname(OUT), exist_ok=True)
    tensors = {f"link.{k}": _half(v) for k, v in link.state_dict().items()}
    tensors.update(
        (f"ali.{k}", _half(v)) for k, v in asker.model.latent_inject.state_dict().items()
    )
    for k, v in asker.state_dict().items():
        tensors[f"asker.{k}"] = _half(v) if v.is_floating_point() else v.detach().cpu().contiguous()

    meta = {
        "kind": "qa", "ans_len": str(ANS_LEN), "prompt": PROMPT,
        "asker": ASKER, "consultant": CONSULTANT,
        "mode": "memorize" if memorize else "generalize",
        "holdout_pct": str(args.holdout), "step": str(step),
        "holdout_exact": f"{acc:.4f}",
        "ops": json.dumps(["+", "-", "*"]),
    }
    staging = OUT + ".tmp"
    save_file(tensors, staging, metadata=meta)
    # Rename over the target in one step: the panel hot-reloads OUT while we train.
    os.replace(staging, OUT)
# Script entry point: importing this module only defines the helpers above
# (so e.g. all_problems / is_holdout can be reused without starting training).
if __name__ == "__main__":
    main()
|