| """Quick test: does our trained model still get GSM8K answers right when forced |
| to emit ONLY the final answer (no verbal CoT in y)? |
| |
| This is the *latent-reasoning value proposition* test: if the K=16 latent |
| vectors actually carry the reasoning, the model should be able to produce |
| "#### N" directly without writing out steps. |
| |
| We don't retrain — we just generate with max_new_tokens small enough that |
| no verbal CoT can fit. Two flavors: |
| short-32: max_new_tokens=32, just enough for "#### NUMBER" + a few extras |
| short-8: max_new_tokens=8, basically forces direct emission |
| |
| The model was trained to emit a full ~150-token GSM8K answer, so we expect |
| this is way out-of-distribution and accuracy will be low. The point is to |
| calibrate: is there ANY chance z carries the reasoning, or is the model |
| hopelessly dependent on writing verbal CoT? |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import re |
| import time |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from ..data import GSM8KDataset, collate_batch |
| from ..model import BLTConfig, LatentProjector, build_base, generate_with_latent |
|
|
|
|
# One numeric literal: optional sign, then either a well-formed
# thousands-separated integer part ("1,234,567") or a plain run of digits,
# followed by an optional decimal part. The (?!\d) guard stops "1,2345" from
# being read as "1,234".
_NUMBER = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"

# GSM8K marks the final answer as "#### <number>".
GSM8K_NUM = re.compile(r"####\s*(" + _NUMBER + r")")
# Fallback: any number anywhere in the text.
ANY_NUM = re.compile(_NUMBER)


def parse_pred(text: str):
    """Extract the predicted final answer from generated text.

    The number following the first "####" marker wins. If there is no such
    marker, the last number appearing anywhere in ``text`` is used instead.

    Args:
        text: Decoded model output.

    Returns:
        The number as a string with thousands separators removed
        (e.g. "1,234" -> "1234"), or ``None`` if ``text`` contains no number.
    """
    marked = GSM8K_NUM.search(text)
    if marked:
        return marked.group(1).replace(",", "")
    candidates = ANY_NUM.findall(text)
    if not candidates:
        return None
    return candidates[-1].replace(",", "")
|
|
|
|
def correct(pred, gold):
    """Return True when ``pred`` and ``gold`` denote the same number.

    Both values are converted to float after dropping thousands separators
    and surrounding whitespace, so "1,234" and "1234" compare equal. The
    comparison uses an absolute tolerance of 1e-4.

    Args:
        pred: Predicted answer (typically a string), or ``None`` when no
            number could be parsed from the model output.
        gold: Reference answer.

    Returns:
        ``False`` if either value is ``None`` or cannot be parsed as a
        number; otherwise whether the two values agree within tolerance.
    """
    if pred is None or gold is None:
        return False
    try:
        pred_val = float(str(pred).replace(",", "").strip())
        gold_val = float(str(gold).replace(",", "").strip())
    except ValueError:
        return False
    return abs(pred_val - gold_val) < 1e-4
|
|
|
|
def main():
    """Run the short-generation GSM8K evaluation from the command line.

    Steps:
      1. Parse CLI arguments and read the JSON training config (``--config``).
      2. Build the base model and, if ``<ckpt>/model/adapter_config.json``
         exists, load the PEFT adapter on top of it.
      3. Load the latent projector weights from ``<ckpt>/projector.pt``.
      4. Generate with at most ``--max_new_tokens`` new tokens per example on
         the GSM8K "test" split (limited via ``max_examples=--n``) and score
         each output with ``parse_pred`` / ``correct``.
      5. Write a JSON summary to ``--out`` (default: a file inside the
         checkpoint directory) and print the accuracy plus a few examples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--config", required=True)
    parser.add_argument("--n", type=int, default=100)
    parser.add_argument("--K", type=int, default=16)
    parser.add_argument("--no_block_y_to_x", action="store_true")
    parser.add_argument("--max_new_tokens", type=int, default=32)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--out", default=None)
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # The base model is built without LoRA; a saved adapter (if any) is
    # attached afterwards through PeftModel.
    # NOTE(review): BLTConfig takes block_y_to_x from the config file while
    # generation below uses the CLI flag -- confirm the two are meant to be
    # independent.
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=args.K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    # Imported lazily so peft is only required when this script actually runs.
    from peft import PeftModel

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        # Make the fallback visible: without an adapter the numbers below
        # describe the plain base model (plus projector), not the fine-tune.
        model = base_model
        print(f"[load] no adapter found in {adapter_dir}; using base model")
    model.to(device).eval()

    # A PEFT wrapper exposes get_base_model(); a plain model is used as-is.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    # NOTE(review): torch.load unpickles the file; only point --ckpt at
    # checkpoints you trust.
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    test_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                           max_prompt_len=cfg["max_prompt_len"],
                                           max_answer_len=cfg["max_answer_len"]),
    )

    block_y_to_x = not args.no_block_y_to_x
    print(f"[mode] block_y_to_x={block_y_to_x}  max_new_tokens={args.max_new_tokens}  K={args.K}")

    correct_n = 0
    total = 0
    examples = []  # first few decoded outputs, kept for manual inspection
    t0 = time.time()
    for batch in loader:
        x_ids = batch.x_ids.to(device)
        x_attn = batch.x_attn.to(device)
        # Evaluation only: make sure no autograd graph is built while generating.
        with torch.no_grad():
            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=args.K,
                block_y_to_x=block_y_to_x,
                max_new_tokens=args.max_new_tokens,
                temperature=0.0, eos_token_id=tokenizer.eos_token_id,
            )
        for b in range(x_ids.size(0)):
            text = tokenizer.decode(gen[b], skip_special_tokens=True)
            pred = parse_pred(text)
            gold = batch.final_strs[b].replace("#### ", "").strip()
            ok = correct(pred, gold)
            correct_n += int(ok)
            total += 1
            if len(examples) < 8:
                examples.append({"text": text[:120], "pred": pred, "gold": gold, "ok": ok})

    summary = {
        "ckpt": str(ckpt), "n": args.n, "K": args.K,
        "block_y_to_x": block_y_to_x,
        "max_new_tokens": args.max_new_tokens,
        "acc": correct_n / max(total, 1),  # guard against an empty dataset
        "correct": correct_n,
        "total": total,
        "elapsed_s": time.time() - t0,
        "examples": examples,
    }
    out = args.out or str(ckpt / f"short_y_eval_M{args.max_new_tokens}_block{block_y_to_x}.json")
    out_path = Path(out)
    # A user-supplied --out may point into a directory that does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[done] acc={summary['acc']:.4f}  ({correct_n}/{total})  elapsed={summary['elapsed_s']:.0f}s")
    print(f"[written] {out}")
    for e in examples[:5]:
        print(f"  text={e['text']!r}  pred={e['pred']}  gold={e['gold']}  ok={e['ok']}")
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|