"""E1: base-model ceiling eval — Qwen2.5-Math-7B-Instruct + standard CoT prompting
on GSM8K-test (n=200), autoregressive decoding.

No latent loop, no projector, no LoRA, no bottleneck. Establishes the verbal-CoT
upper bound that our BLT-Reasoner is trying to approach. Uses the SAME 200 test
problems and the same accuracy parsing as our BLT eval, for a clean side-by-side
comparison.
"""

from __future__ import annotations

import argparse
import json
import re
import time
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader  # noqa: F401  (currently unused in this script)

from ..data import GSM8KDataset, collate_batch  # noqa: F401  (collate_batch currently unused)

COT_PROMPT_TEMPLATE = (
    "Please solve the following math problem step by step. End your response "
    "with '#### N' where N is the final numerical answer.\n\n"
    "Question: {question}\nAnswer:"
)

# Explicit GSM8K-style final-answer marker, e.g. "#### 42".
# NOTE: a number written with thousands separators ("#### 1,200") is captured only
# up to the first comma. Kept as-is so scoring matches the BLT eval exactly.
GSM8K_NUM = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
# Fallback: any signed integer or decimal literal.
ANY_NUM = re.compile(r"-?\d+(?:\.\d+)?")


def parse_pred(text: str) -> Optional[str]:
    """Extract the predicted final answer from a model completion.

    Prefers the first ``#### N`` marker. If there is none, falls back to the
    last number appearing anywhere in *text*.

    Returns:
        The matched number as a string, or ``None`` if *text* has no number.
    """
    m = GSM8K_NUM.search(text)
    if m:
        return m.group(1)
    nums = ANY_NUM.findall(text)
    return nums[-1] if nums else None


def correct(pred: Optional[str], gold: str) -> bool:
    """Return True when *pred* and *gold* parse to floats within 1e-4 of each other.

    A missing prediction (``None``) counts as wrong. So does a value that
    ``float()`` rejects (``ValueError``).
    """
    if pred is None:
        return False
    try:
        return abs(float(pred) - float(gold)) < 1e-4
    except ValueError:
        return False


def _eos_token_ids(tok, model) -> list:
    """Collect every end-of-sequence id declared by the tokenizer or the model.

    Passing only ``tok.eos_token_id`` to ``generate`` overrides the model's
    ``generation_config``, which may list additional stop ids. Merging them
    lets generation stop at any declared EOS. The tokenizer's id comes first.
    """
    gen_cfg = getattr(model, "generation_config", None)
    cfg_eos = getattr(gen_cfg, "eos_token_id", None)
    if isinstance(cfg_eos, int):
        cfg_eos = [cfg_eos]
    ids = []
    for tid in [tok.eos_token_id, *(cfg_eos or [])]:
        if tid is not None and tid not in ids:
            ids.append(tid)
    return ids


def _generation_kwargs(temperature: float, max_new_tokens: int, pad_token_id, eos_token_ids) -> dict:
    """Build keyword arguments for ``model.generate``.

    ``temperature <= 0`` means greedy decoding. In that case no temperature is
    passed at all, since it would be ignored and only trigger warnings.
    ``temperature > 0`` enables sampling at that temperature.
    """
    kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": pad_token_id,
        "eos_token_id": eos_token_ids or None,
    }
    if temperature > 0:
        kwargs["do_sample"] = True
        kwargs["temperature"] = temperature
    else:
        kwargs["do_sample"] = False
    return kwargs


def main():
    """Run the base-model CoT eval on GSM8K-test and write a JSON summary to ``--out``."""
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="Qwen/Qwen2.5-Math-7B-Instruct")
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--max_new_tokens", type=int, default=384)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="generation batch size (4 was chosen for memory safety with max_new_tokens=384)",
    )
    p.add_argument("--out", default="/home/ubuntu/work/base_qwen_gsm8k.json")
    args = p.parse_args()
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")

    # Create the output directory now, so a bad path fails before the long eval loop.
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Batched generation with a decoder-only model needs LEFT padding, so that every
    # prompt ends exactly where generation starts. Set it on the tokenizer instead of
    # per call: the per-call `padding_side=` kwarg is not supported by older
    # transformers releases.
    tok.padding_side = "left"

    print(f"[load] {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        trust_remote_code=True,
    ).to(device).eval()
    model.config.use_cache = True
    print("[load] done")

    ds = GSM8KDataset(split="test", max_examples=args.n)
    print(f"[data] n={len(ds)}")

    gen_kwargs = _generation_kwargs(
        args.temperature,
        args.max_new_tokens,
        tok.pad_token_id,
        _eos_token_ids(tok, model),
    )

    correct_n = 0
    total = 0
    examples_log = []
    t0 = time.time()

    bs = args.batch_size
    for start in range(0, len(ds), bs):
        batch_items = [ds[j] for j in range(start, min(start + bs, len(ds)))]
        # NOTE(review): the prompt is fed as raw text, not through the tokenizer's
        # chat template. Confirm this is the intended "standard CoT" baseline.
        prompts = [COT_PROMPT_TEMPLATE.format(question=ex["question"]) for ex in batch_items]
        enc = tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.inference_mode():
            gen = model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )
        # With left padding, every row's prompt ends at the same column,
        # so a single slice strips all prompts.
        gen_only = gen[:, enc["input_ids"].size(1):]
        decoded = tok.batch_decode(gen_only, skip_special_tokens=True)

        for ex, text in zip(batch_items, decoded):
            pred = parse_pred(text)
            gold = ex["final"]
            ok = correct(pred, gold)
            correct_n += int(ok)
            total += 1
            if len(examples_log) < 5:
                examples_log.append({"text": text[:400], "pred": pred, "gold": gold, "ok": ok})
            if total % 20 == 0:
                print(
                    f"  progress {total}/{len(ds)} acc_so_far={correct_n/total:.3f} "
                    f"elapsed={time.time()-t0:.0f}s"
                )

    elapsed = time.time() - t0
    summary = {
        "base_model": args.base_model,
        "n": total,
        "correct": correct_n,
        "acc": correct_n / max(total, 1),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "batch_size": bs,
        "elapsed_s": round(elapsed, 1),
        "prompt_template": COT_PROMPT_TEMPLATE,
        "examples": examples_log,
    }
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={elapsed:.0f}s")
    print(f"[written] {args.out}")


if __name__ == "__main__":
    main()