# blt-reasoner-pilot1/code/scripts/eval_short_y.py
# (hosting-page residue) LauraGG's picture
# Commit message: Refresh code/ with latest BLT-Reasoner sources (post-campaign)
# Commit: bc7101b (verified)
"""Quick test: does our trained model still get GSM8K answers right when forced
to emit ONLY the final answer (no verbal CoT in y)?
This is the *latent-reasoning value proposition* test: if the K=16 latent
vectors actually carry the reasoning, the model should be able to produce
"#### N" directly without writing out steps.
We don't retrain — we just generate with max_new_tokens small enough that
no verbal CoT can fit. Two flavors:
short-32: max_new_tokens=32, just enough for "#### NUMBER" + a few extras
short-8: max_new_tokens=8, basically forces direct emission
The model was trained to emit a full ~150-token GSM8K answer, so we expect
this is way out-of-distribution and accuracy will be low. The point is to
calibrate: is there ANY chance z carries the reasoning, or is the model
hopelessly dependent on writing verbal CoT?
"""
from __future__ import annotations
import argparse
import json
import re
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, generate_with_latent
# Canonical GSM8K final-answer marker, "#### <number>".  The number may carry
# thousands separators ("1,000") and an optional decimal part.  The lookahead
# keeps ",ddd" from swallowing the first three digits of a longer digit run.
GSM8K_NUM = re.compile(r"####\s*(-?\d+(?:,\d{3}(?!\d))*(?:\.\d+)?)")
# Fallback pattern: any (optionally signed / comma-grouped / decimal) number.
ANY_NUM = re.compile(r"-?\d+(?:,\d{3}(?!\d))*(?:\.\d+)?")


def parse_pred(text: str) -> str | None:
    """Extract the predicted final answer from generated text.

    The number following the first ``####`` marker wins.  If there is no such
    marker, the last number appearing anywhere in ``text`` is used instead.

    Args:
        text: Decoded model output.

    Returns:
        The number as a string with thousands separators removed (so it can
        be passed straight to ``float``), or ``None`` if ``text`` contains no
        number at all.
    """
    marked = GSM8K_NUM.search(text)
    if marked:
        return marked.group(1).replace(",", "")
    candidates = ANY_NUM.findall(text)
    if not candidates:
        return None
    # Without the marker, the last number is the best guess for the answer.
    return candidates[-1].replace(",", "")
def correct(pred, gold):
    """Return True when ``pred`` and ``gold`` denote the same number.

    Both values are converted to ``float`` after removing thousands
    separators, so ``"1,000"`` and ``"1000"`` compare equal.  The comparison
    uses an absolute tolerance of 1e-4.

    Args:
        pred: Predicted answer (typically a string), or ``None`` when no
            number could be parsed from the model output.
        gold: Reference answer (typically a string).

    Returns:
        ``True`` if the two values agree within tolerance; ``False`` if they
        differ, if either is ``None``, or if either is not parseable as a
        number.
    """
    if pred is None or gold is None:
        return False
    try:
        pred_val = float(str(pred).replace(",", ""))
        gold_val = float(str(gold).replace(",", ""))
    except ValueError:
        # A non-numeric string on either side counts as a miss, not a crash.
        return False
    return abs(pred_val - gold_val) < 1e-4
def main() -> None:
    """Evaluate a checkpoint on GSM8K with a tight generation budget.

    Steps, in order:
      1. Parse CLI arguments and load the JSON training config.
      2. Build the base model, attach the PEFT adapter found under
         ``<ckpt>/model`` (if any), and load ``<ckpt>/projector.pt``.
      3. Generate at most ``--max_new_tokens`` tokens per test question and
         score each output with ``parse_pred`` / ``correct``.
      4. Write a JSON summary (accuracy, counts, timing, a few examples) and
         print a short report.

    The summary goes to ``--out`` when given, otherwise to a file inside the
    checkpoint directory named after the token budget and blocking mode.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True,
                   help="checkpoint directory (projector.pt, optional model/ adapter)")
    p.add_argument("--config", required=True, help="path to the JSON config file")
    p.add_argument("--n", type=int, default=100,
                   help="max_examples passed to the GSM8K test dataset")
    p.add_argument("--K", type=int, default=16, help="number of latents")
    p.add_argument("--no_block_y_to_x", action="store_true",
                   help="generate with block_y_to_x=False")
    p.add_argument("--max_new_tokens", type=int, default=32,
                   help="generation budget per example")
    p.add_argument("--out", default=None,
                   help="output JSON path (default: inside the checkpoint directory)")
    args = p.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # use_lora=False: the bare base model is built here and the trained
    # adapter (if present) is attached from disk just below.
    # NOTE(review): the config object takes block_y_to_x from the JSON config,
    # while generation uses the CLI flag — confirm this split is intended.
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=args.K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    from peft import PeftModel
    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()

    # A PEFT wrapper exposes the wrapped model via get_base_model(); the
    # hidden size is read from whichever object carries the model config.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=8, shuffle=False,
        collate_fn=lambda b: collate_batch(b, tokenizer,
                                           max_prompt_len=cfg["max_prompt_len"],
                                           max_answer_len=cfg["max_answer_len"]),
    )

    block_y_to_x = not args.no_block_y_to_x
    print(f"[mode] block_y_to_x={block_y_to_x} max_new_tokens={args.max_new_tokens} K={args.K}")

    correct_n = 0
    total = 0
    examples = []
    t0 = time.time()
    # Inference only: disable autograd so no graph is kept alive while
    # generating, regardless of what the generation helper does internally.
    with torch.no_grad():
        for batch in loader:
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            batch_size = x_ids.size(0)
            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=args.K,
                block_y_to_x=block_y_to_x,
                max_new_tokens=args.max_new_tokens,
                temperature=0.0, eos_token_id=tokenizer.eos_token_id,
            )
            for b in range(batch_size):
                text = tokenizer.decode(gen[b], skip_special_tokens=True)
                pred = parse_pred(text)
                gold = batch.final_strs[b].replace("#### ", "").strip()
                ok = correct(pred, gold)
                correct_n += int(ok)
                total += 1
                # Keep only the first few generations for manual inspection.
                if len(examples) < 8:
                    examples.append({"text": text[:120], "pred": pred, "gold": gold, "ok": ok})

    summary = {
        "ckpt": str(ckpt), "n": args.n, "K": args.K,
        "block_y_to_x": block_y_to_x,
        "max_new_tokens": args.max_new_tokens,
        # max(total, 1) guards against an empty dataset.
        "acc": correct_n / max(total, 1),
        "correct": correct_n,
        "total": total,
        "elapsed_s": time.time() - t0,
        "examples": examples,
    }
    out = args.out or str(ckpt / f"short_y_eval_M{args.max_new_tokens}_block{block_y_to_x}.json")
    out_path = Path(out)
    # Create the destination directory so a custom --out cannot fail after
    # the whole evaluation has already been run.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={summary['elapsed_s']:.0f}s")
    print(f"[written] {out}")
    for e in examples[:5]:
        print(f" text={e['text']!r} pred={e['pred']} gold={e['gold']} ok={e['ok']}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()