# Source: blt-reasoner-pilot1 / code/scripts/eval_base_gsm8k.py
# Committed by: LauraGG
# Commit message: "Refresh code/ with latest BLT-Reasoner sources (post-campaign)"
# Revision: bc7101b (verified)
"""E1: base-model ceiling eval — Qwen2.5-Math-7B-Instruct + standard CoT
prompting on GSM8K-test n=200 AR. No latent loop, no projector, no LoRA, no
bottleneck. Establishes the verbal-CoT upper bound our BLT-Reasoner is
trying to approach.
Uses the SAME 200 test problems and same accuracy parsing as our BLT eval
for a clean side-by-side comparison.
"""
from __future__ import annotations
import argparse
import json
import re
import time
from pathlib import Path
from typing import Optional
import torch
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, collate_batch
# Chain-of-thought prompt. The model is asked to finish with a "#### N" line,
# which is the marker GSM8K_NUM below looks for first.
COT_PROMPT_TEMPLATE = (
    "Please solve the following math problem step by step. "
    "End your response with '#### N' where N is the final numerical answer."
    "\n\nQuestion: {question}\nAnswer:"
)

# Optional leading minus, integer part, optional decimal fraction.
# NOTE(review): no thousands separators -- "1,234" is seen as "1" and "234".
_NUMBER = r"-?\d+(?:\.\d+)?"
# "####" followed by optional whitespace, capturing the number after it.
GSM8K_NUM = re.compile(r"####\s*(" + _NUMBER + r")")
# Any bare number; used as a fallback when no "####" marker is present.
ANY_NUM = re.compile(_NUMBER)
def parse_pred(text: str) -> Optional[str]:
    """Extract the predicted final answer from a model completion.

    The number after the first ``####`` marker wins. If there is no marker,
    the last number anywhere in ``text`` is used instead. Returns ``None``
    when ``text`` contains no number at all.

    NOTE(review): thousands separators are not joined. For ``#### 1,234``
    the captured value is ``"1"``. This is left unchanged on purpose, so
    scoring stays identical to the BLT eval this script is compared against.
    """
    marked = GSM8K_NUM.search(text)
    if marked is not None:
        return marked.group(1)
    candidates = ANY_NUM.findall(text)
    if not candidates:
        return None
    return candidates[-1]
def correct(pred: Optional[str], gold: str) -> bool:
    """Return True when ``pred`` numerically matches ``gold``.

    Both strings are converted with ``float`` and compared with an absolute
    tolerance of 1e-4. The result is False in two cases: a missing
    prediction (``None``), or either value being a string that ``float``
    rejects with ``ValueError``.
    """
    if pred is not None:
        try:
            delta = float(pred) - float(gold)
        except ValueError:
            return False
        return abs(delta) < 1e-4
    return False
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for the base-model GSM8K eval."""
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="Qwen/Qwen2.5-Math-7B-Instruct")
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--max_new_tokens", type=int, default=384)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--out", default="/home/ubuntu/work/base_qwen_gsm8k.json")
    # Default 4 matches the original hard-coded batch, which was chosen for
    # memory safety with max_new_tokens=384.
    p.add_argument("--batch_size", type=int, default=4)
    args = p.parse_args()
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")
    return args


def _eos_token_ids(model, tok) -> Optional[list[int]]:
    """Merge the tokenizer's EOS id with the model's generation-config EOS ids.

    ``model.generation_config.eos_token_id`` may be an int, a list of ints,
    or None. Passing only ``tok.eos_token_id`` to ``generate`` would override
    that config, and generation would not stop on the other end tokens.
    Returns None when neither source defines an EOS id.
    """
    ids: list[int] = []
    cfg_eos = getattr(model.generation_config, "eos_token_id", None)
    if isinstance(cfg_eos, int):
        ids.append(cfg_eos)
    elif cfg_eos is not None:
        ids.extend(int(i) for i in cfg_eos)
    if tok.eos_token_id is not None and tok.eos_token_id not in ids:
        ids.insert(0, tok.eos_token_id)
    return ids or None


def main():
    """Evaluate the base model with plain CoT prompting on GSM8K test.

    Steps:
      1. Load ``--base_model`` and its tokenizer.
      2. Generate a completion for each of the first ``--n`` test problems,
         ``--batch_size`` prompts at a time.
      3. Score each completion with ``parse_pred`` / ``correct`` and print
         running accuracy every 20 problems.
      4. Write a JSON summary to ``--out``, creating its parent directory
         if needed. The summary holds accuracy, settings, the prompt
         template and the first 5 completions.
    """
    args = _parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Batched decoder-only generation needs left padding, so that every prompt
    # ends exactly where generation begins. Setting the attribute works across
    # transformers versions; older releases ignore a per-call padding_side=.
    tok.padding_side = "left"

    print(f"[load] {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        trust_remote_code=True,
    ).to(device).eval()
    model.config.use_cache = True
    print("[load] done")

    ds = GSM8KDataset(split="test", max_examples=args.n)
    print(f"[data] n={len(ds)}")

    gen_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,
        # Ignored by greedy decoding; clamped so sampling never sees T == 0.
        temperature=max(args.temperature, 1e-6),
        pad_token_id=tok.pad_token_id,
        eos_token_id=_eos_token_ids(model, tok),
    )

    correct_n = 0
    total = 0
    examples_log = []
    t0 = time.time()
    bs = args.batch_size
    for start in range(0, len(ds), bs):
        batch_items = [ds[j] for j in range(start, min(start + bs, len(ds)))]
        prompts = [COT_PROMPT_TEMPLATE.format(question=ex["question"]) for ex in batch_items]
        enc = tok(
            prompts, return_tensors="pt", padding=True,
            truncation=True, max_length=512,
        ).to(device)
        with torch.no_grad():
            gen = model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )
        # With left padding every row's prompt occupies the first
        # input_ids.size(1) positions; the rest is newly generated text.
        gen_only = gen[:, enc["input_ids"].size(1):]
        decoded = tok.batch_decode(gen_only, skip_special_tokens=True)
        for ex, text in zip(batch_items, decoded):
            pred = parse_pred(text)
            gold = ex["final"]
            ok = correct(pred, gold)
            correct_n += int(ok)
            total += 1
            if len(examples_log) < 5:
                examples_log.append({"text": text[:400], "pred": pred, "gold": gold, "ok": ok})
            if total % 20 == 0:
                print(f"  progress {total}/{len(ds)} acc_so_far={correct_n/total:.3f} elapsed={time.time()-t0:.0f}s")

    summary = {
        "base_model": args.base_model,
        "n": total,
        "correct": correct_n,
        "acc": correct_n / max(total, 1),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "batch_size": bs,
        "prompt_template": COT_PROMPT_TEMPLATE,
        "examples": examples_log,
    }
    out_path = Path(args.out)
    # Create the output directory so a long run does not fail at the final write.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={time.time()-t0:.0f}s")
    print(f"[written] {args.out}")


if __name__ == "__main__":
    main()