"""Quick test: does our trained model still get GSM8K answers right when forced to emit ONLY the final answer (no verbal CoT in y)? This is the *latent-reasoning value proposition* test: if the K=16 latent vectors actually carry the reasoning, the model should be able to produce "#### N" directly without writing out steps. We don't retrain — we just generate with max_new_tokens small enough that no verbal CoT can fit. Two flavors: short-32: max_new_tokens=32, just enough for "#### NUMBER" + a few extras short-8: max_new_tokens=8, basically forces direct emission The model was trained to emit a full ~150-token GSM8K answer, so we expect this is way out-of-distribution and accuracy will be low. The point is to calibrate: is there ANY chance z carries the reasoning, or is the model hopelessly dependent on writing verbal CoT? """ from __future__ import annotations import argparse import json import re import time from pathlib import Path import torch from torch.utils.data import DataLoader from ..data import GSM8KDataset, collate_batch from ..model import BLTConfig, LatentProjector, build_base, generate_with_latent GSM8K_NUM = re.compile(r"####\s*(-?\d+(?:\.\d+)?)") ANY_NUM = re.compile(r"-?\d+(?:\.\d+)?") def parse_pred(text: str): m = GSM8K_NUM.search(text) if m: return m.group(1) nums = ANY_NUM.findall(text) return nums[-1] if nums else None def correct(pred, gold): if pred is None: return False try: return abs(float(pred) - float(gold)) < 1e-4 except ValueError: return False def main(): p = argparse.ArgumentParser() p.add_argument("--ckpt", required=True) p.add_argument("--config", required=True) p.add_argument("--n", type=int, default=100) p.add_argument("--K", type=int, default=16) p.add_argument("--no_block_y_to_x", action="store_true") p.add_argument("--max_new_tokens", type=int, default=32) p.add_argument("--out", default=None) args = p.parse_args() with open(args.config) as f: cfg = json.load(f) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ckpt = Path(args.ckpt) bcfg = BLTConfig( base_model=cfg["base_model"], use_lora=False, lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"], lora_dropout=cfg["lora_dropout"], lora_target_modules=tuple(cfg["lora_target_modules"]), K_latents=args.K, block_y_to_x=cfg["block_y_to_x"], proj_init_scale=cfg["proj_init_scale"], dtype=cfg["dtype"], attn_impl=cfg["attn_impl"], gradient_checkpointing=False, ) base_model, tokenizer = build_base(bcfg) from peft import PeftModel adapter_dir = ckpt / "model" if (adapter_dir / "adapter_config.json").exists(): model = PeftModel.from_pretrained(base_model, str(adapter_dir)) print(f"[load] adapter from {adapter_dir}") else: model = base_model model.to(device).eval() inner = model.get_base_model() if hasattr(model, "get_base_model") else model d_model = inner.config.hidden_size projector = LatentProjector( d_model, init_scale=cfg["proj_init_scale"], use_mlp=cfg.get("proj_mlp", False), hidden_mult=cfg.get("proj_hidden_mult", 4), ).to(device).to(next(model.parameters()).dtype) projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device)) projector.eval() val_ds = GSM8KDataset(split="test", max_examples=args.n) loader = DataLoader( val_ds, batch_size=8, shuffle=False, collate_fn=lambda b: collate_batch(b, tokenizer, max_prompt_len=cfg["max_prompt_len"], max_answer_len=cfg["max_answer_len"]), ) block_y_to_x = not args.no_block_y_to_x print(f"[mode] block_y_to_x={block_y_to_x} max_new_tokens={args.max_new_tokens} K={args.K}") correct_n = 0 total = 0 examples = [] t0 = time.time() for batch in loader: x_ids = batch.x_ids.to(device); x_attn = batch.x_attn.to(device) B = x_ids.size(0) gen = generate_with_latent( model, tokenizer, projector, x_ids=x_ids, x_attn=x_attn, K=args.K, block_y_to_x=block_y_to_x, max_new_tokens=args.max_new_tokens, temperature=0.0, eos_token_id=tokenizer.eos_token_id, ) for b in range(B): text = tokenizer.decode(gen[b], skip_special_tokens=True) pred = parse_pred(text) gold = batch.final_strs[b].replace("#### ", "").strip() ok = correct(pred, gold) correct_n += int(ok) total += 1 if len(examples) < 8: examples.append({"text": text[:120], "pred": pred, "gold": gold, "ok": ok}) summary = { "ckpt": str(ckpt), "n": args.n, "K": args.K, "block_y_to_x": block_y_to_x, "max_new_tokens": args.max_new_tokens, "acc": correct_n / max(total, 1), "correct": correct_n, "total": total, "elapsed_s": time.time() - t0, "examples": examples, } out = args.out or str(ckpt / f"short_y_eval_M{args.max_new_tokens}_block{block_y_to_x}.json") Path(out).write_text(json.dumps(summary, indent=2)) print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={summary['elapsed_s']:.0f}s") print(f"[written] {out}") for e in examples[:5]: print(f" text={e['text']!r} pred={e['pred']} gold={e['gold']} ok={e['ok']}") if __name__ == "__main__": main()