| """E1: base-model ceiling eval — Qwen2.5-Math-7B-Instruct + standard CoT |
| prompting on GSM8K-test n=200 AR. No latent loop, no projector, no LoRA, no |
| bottleneck. Establishes the verbal-CoT upper bound our BLT-Reasoner is |
| trying to approach. |
| |
| Uses the SAME 200 test problems and same accuracy parsing as our BLT eval |
| for a clean side-by-side comparison. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import re |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from ..data import GSM8KDataset, collate_batch |
|
|
|
|
# Plain-text chain-of-thought prompt. ``{question}`` is filled in with
# ``str.format``; the model is told to finish with ``#### N`` so the final
# answer can be picked out of the generated text with a regex.
COT_PROMPT_TEMPLATE = (
    "Please solve the following math problem step by step. "
    "End your response with '#### N' where N is the final numerical answer."
    "\n\n"
    "Question: {question}\n"
    "Answer:"
)
|
|
|
|
# A signed decimal number, optionally written with thousands separators
# ("1,234,567.89"). The comma-grouped alternative demands exactly three digits
# after every comma and no digit right after the last group, so enumerations
# such as "3, 4" or malformed runs such as "1,2345" are NOT merged into one
# number and fall through to the plain ``\d+`` alternative.
_NUM = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"

# Number that follows the GSM8K-style "####" answer marker (captured in group 1).
GSM8K_NUM = re.compile(r"####\s*(" + _NUM + r")")
# Any number anywhere in the text; used only as a fallback. Deliberately has no
# capturing groups so that ``findall`` returns whole matches.
ANY_NUM = re.compile(_NUM)


def parse_pred(text: str) -> Optional[str]:
    """Extract the predicted final answer from generated text.

    The number following the first ``####`` marker is preferred. If no marker
    followed by a number is present, the last number appearing anywhere in
    ``text`` is used instead.

    Thousands separators are accepted and removed, so ``"#### 1,234"`` yields
    ``"1234"`` rather than being cut off at the comma.

    NOTE(review): the module docstring says this parsing is shared with the
    BLT eval; if that eval keeps its own copy of the parser, mirror the
    comma handling there so the two accuracies stay comparable.

    Args:
        text: Decoded model output.

    Returns:
        The answer as a string that ``float()`` can parse, or ``None`` when
        ``text`` contains no number at all.
    """
    marked = GSM8K_NUM.search(text)
    if marked:
        return marked.group(1).replace(",", "")
    numbers = ANY_NUM.findall(text)
    if not numbers:
        return None
    return numbers[-1].replace(",", "")
|
|
|
|
def correct(pred: Optional[str], gold: str) -> bool:
    """Tell whether a parsed prediction matches the gold answer numerically.

    Both values are converted with ``float`` and compared with an absolute
    tolerance of ``1e-4``.

    Args:
        pred: Parsed prediction, or ``None`` when nothing could be parsed.
        gold: Reference answer as a string.

    Returns:
        ``True`` when both values parse as floats and differ by less than
        ``1e-4``; ``False`` when ``pred`` is ``None`` or either string is not
        a valid float literal.
    """
    if pred is None:
        return False
    try:
        delta = float(pred) - float(gold)
    except ValueError:
        # One of the two strings is not a float literal.
        return False
    return abs(delta) < 1e-4
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the base-model CoT evaluation."""
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="Qwen/Qwen2.5-Math-7B-Instruct")
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--max_new_tokens", type=int, default=384)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--batch_size", type=int, default=4,
                   help="number of prompts generated per forward batch")
    p.add_argument("--out", default="/home/ubuntu/work/base_qwen_gsm8k.json")
    return p


def _generation_kwargs(args: argparse.Namespace, tok) -> dict:
    """Assemble the keyword arguments passed to ``model.generate``.

    Decoding is greedy when ``args.temperature <= 0`` and sampled otherwise.
    ``temperature`` is only forwarded in the sampled case: greedy decoding does
    not use it, so passing it alongside ``do_sample=False`` is at best noise.
    """
    kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if args.temperature > 0:
        kwargs["do_sample"] = True
        # Same lower clamp as before, guarding against vanishingly small values.
        kwargs["temperature"] = max(args.temperature, 1e-6)
    else:
        kwargs["do_sample"] = False
    return kwargs


def main():
    """Evaluate a base causal LM with plain CoT prompting on the GSM8K test split.

    Steps: parse CLI arguments, load tokenizer and model, build one prompt per
    example from ``COT_PROMPT_TEMPLATE``, generate in batches, parse each
    completion with ``parse_pred`` and score it against ``ex["final"]`` with
    ``correct``. A JSON summary (accuracy, settings, prompt template and the
    first five examples) is written to ``--out``.

    NOTE(review): prompts are tokenized as raw text; the tokenizer's chat
    template is not applied. For an Instruct checkpoint this may understate
    the ceiling — confirm this is intended.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    # Create the output directory up front so an unwritable path fails before
    # the (long) evaluation instead of after it.
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Left padding keeps every prompt flush against its generated tokens, which
    # the ``gen[:, prompt_len:]`` slice below relies on. Set on the tokenizer
    # itself rather than as a call-time kwarg, which not every transformers
    # release accepts.
    tok.padding_side = "left"

    print(f"[load] {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
        trust_remote_code=True,
    ).to(device).eval()
    model.config.use_cache = True
    print("[load] done")

    ds = GSM8KDataset(split="test", max_examples=args.n)
    print(f"[data] n={len(ds)}")

    gen_kwargs = _generation_kwargs(args, tok)
    correct_n = 0
    total = 0
    examples_log = []
    t0 = time.time()

    for start in range(0, len(ds), args.batch_size):
        stop = min(start + args.batch_size, len(ds))
        batch_items = [ds[j] for j in range(start, stop)]
        prompts = [COT_PROMPT_TEMPLATE.format(question=ex["question"]) for ex in batch_items]
        enc = tok(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        with torch.no_grad():
            gen = model.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                **gen_kwargs,
            )
        # Drop the (left-padded) prompt columns; what remains is the completion.
        prompt_len = enc["input_ids"].size(1)
        decoded = tok.batch_decode(gen[:, prompt_len:], skip_special_tokens=True)

        for ex, text in zip(batch_items, decoded):
            pred = parse_pred(text)
            gold = ex["final"]
            ok = correct(pred, gold)
            correct_n += int(ok)
            total += 1
            # Keep only a handful of truncated samples for eyeballing.
            if len(examples_log) < 5:
                examples_log.append({"text": text[:400], "pred": pred, "gold": gold, "ok": ok})
            if total % 20 == 0:
                print(f"  progress {total}/{len(ds)} acc_so_far={correct_n/total:.3f} elapsed={time.time()-t0:.0f}s")

    summary = {
        "base_model": args.base_model,
        "n": total,
        "correct": correct_n,
        # max(..., 1) avoids a ZeroDivisionError on an empty dataset.
        "acc": correct_n / max(total, 1),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "batch_size": args.batch_size,
        "prompt_template": COT_PROMPT_TEMPLATE,
        "examples": examples_log,
    }
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={time.time()-t0:.0f}s")
    print(f"[written] {args.out}")
|
|
|
# Script entry point. The module uses a package-relative import
# (``from ..data import ...``), so it has to be launched as a module of its
# package (``python -m <package>.<module>``) rather than by file path.
if __name__ == "__main__":
    main()
|
|