File size: 4,459 Bytes
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""E1: base-model ceiling eval — Qwen2.5-Math-7B-Instruct + standard CoT
prompting on GSM8K-test n=200 AR. No latent loop, no projector, no LoRA, no
bottleneck. Establishes the verbal-CoT upper bound our BLT-Reasoner is
trying to approach.

Uses the SAME 200 test problems and same accuracy parsing as our BLT eval
for a clean side-by-side comparison.
"""
from __future__ import annotations

import argparse
import json
import re
import time
from pathlib import Path
from typing import Optional

import torch
from torch.utils.data import DataLoader

from ..data import GSM8KDataset, collate_batch


# Plain-text chain-of-thought prompt (filled with str.format; no chat template
# is applied). ``{question}`` is substituted per example. The "#### N"
# instruction asks for the marker that parse_pred searches for first. The exact
# text is also saved in the output JSON under "prompt_template", so any edit
# here changes the recorded experiment setup.
COT_PROMPT_TEMPLATE = (
    "Please solve the following math problem step by step. End your response "
    "with '#### N' where N is the final numerical answer.\n\n"
    "Question: {question}\nAnswer:"
)


GSM8K_NUM = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
ANY_NUM = re.compile(r"-?\d+(?:\.\d+)?")
# A comma between a digit and exactly three further digits, i.e. a thousands
# separator as in "1,234,567". Commas in lists such as "1,2,3" are left alone.
_THOUSANDS_COMMA = re.compile(r"(?<=\d),(?=\d{3}(?!\d))")


def parse_pred(text: str) -> Optional[str]:
    """Extract the predicted final answer from a model completion.

    Thousands-separator commas are removed first. Then the number following
    the first ``####`` marker is returned; if there is no marker, the last
    number anywhere in the text is used as a fallback.

    NOTE(review): if the BLT eval keeps its own copy of this parser, apply the
    same comma normalization there to keep the side-by-side comparison exact.

    Args:
        text: Decoded model completion.

    Returns:
        The answer as a string of digits (optionally signed, optionally with a
        decimal part), or ``None`` if the text contains no number.
    """
    # Without this, "#### 1,234" would be parsed as "1" (and the fallback
    # would yield "234").
    text = _THOUSANDS_COMMA.sub("", text)
    m = GSM8K_NUM.search(text)
    if m:
        return m.group(1)
    nums = ANY_NUM.findall(text)
    return nums[-1] if nums else None


def correct(pred: Optional[str], gold: str) -> bool:
    """Return True if ``pred`` equals ``gold`` numerically (|diff| < 1e-4).

    Commas are stripped from both values before conversion, so a gold label
    written as "1,234" matches a prediction of "1234". Returns False when the
    prediction is missing or either value does not parse as a number.
    """
    if pred is None:
        return False
    try:
        # str() keeps non-str golds (e.g. ints) working as float() did before.
        pred_val = float(str(pred).replace(",", ""))
        gold_val = float(str(gold).replace(",", ""))
    except ValueError:
        return False
    return abs(pred_val - gold_val) < 1e-4


def main():
    """Run the E1 base-model ceiling eval and write a JSON summary to --out.

    Loads ``--base_model`` (bfloat16, eager attention) and formats the first
    ``--n`` GSM8K test problems with ``COT_PROMPT_TEMPLATE``. It generates in
    batches of ``--batch_size``, greedily when ``--temperature <= 0`` and by
    sampling otherwise. Each completion is scored with ``parse_pred`` /
    ``correct``. The JSON summary holds accuracy, the run settings and the
    first five completions.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="Qwen/Qwen2.5-Math-7B-Instruct")
    p.add_argument("--n", type=int, default=200)
    p.add_argument("--max_new_tokens", type=int, default=384)
    p.add_argument("--temperature", type=float, default=0.0)
    # Default 4 was picked for memory safety with max_new_tokens=384.
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--out", default="/home/ubuntu/work/base_qwen_gsm8k.json")
    args = p.parse_args()
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Batched decoder-only generation needs LEFT padding, so that every prompt
    # ends exactly where generation starts. This also makes the
    # gen[:, prompt_len:] slice below correct. Set it on the tokenizer: a
    # per-call ``padding_side=`` kwarg is not accepted by older transformers
    # releases.
    tok.padding_side = "left"
    print(f"[load] {args.base_model}")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model, torch_dtype=torch.bfloat16,
        attn_implementation="eager", trust_remote_code=True,
    ).to(device).eval()
    model.config.use_cache = True
    print("[load] done")

    ds = GSM8KDataset(split="test", max_examples=args.n)
    print(f"[data] n={len(ds)}")

    # Pass ``temperature`` only when sampling. Under greedy decoding it is
    # unused, and transformers warns if it is set.
    if args.temperature > 0:
        decode_kwargs = {"do_sample": True, "temperature": args.temperature}
    else:
        decode_kwargs = {"do_sample": False}

    correct_n = 0
    total = 0
    examples_log = []
    t0 = time.time()

    bs = args.batch_size
    for i in range(0, len(ds), bs):
        batch_items = [ds[j] for j in range(i, min(i + bs, len(ds)))]
        prompts = [COT_PROMPT_TEMPLATE.format(question=ex["question"]) for ex in batch_items]
        enc = tok(prompts, return_tensors="pt", padding=True,
                  truncation=True, max_length=512).to(device)
        with torch.inference_mode():
            gen = model.generate(
                input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tok.pad_token_id,
                eos_token_id=tok.eos_token_id,
                **decode_kwargs,
            )
        # With left padding, every row's prompt occupies the first size(1) columns.
        gen_only = gen[:, enc["input_ids"].size(1):]
        decoded = tok.batch_decode(gen_only, skip_special_tokens=True)
        prev_total = total
        for ex, text in zip(batch_items, decoded):
            pred = parse_pred(text)
            gold = ex["final"]
            ok = correct(pred, gold)
            correct_n += int(ok)
            total += 1
            if len(examples_log) < 5:
                examples_log.append({"text": text[:400], "pred": pred, "gold": gold, "ok": ok})
        # Log whenever a multiple of 20 is crossed, whatever the batch size.
        if total // 20 > prev_total // 20:
            print(f"  progress {total}/{len(ds)}  acc_so_far={correct_n/total:.3f}  elapsed={time.time()-t0:.0f}s")

    summary = {
        "base_model": args.base_model,
        "n": total,
        "correct": correct_n,
        "acc": correct_n / max(total, 1),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "batch_size": args.batch_size,
        "prompt_template": COT_PROMPT_TEMPLATE,
        "examples": examples_log,
    }
    out_path = Path(args.out)
    # Create the output directory if needed, so a long run does not fail at
    # the final write.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(summary, indent=2))
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total})  elapsed={time.time()-t0:.0f}s")
    print(f"[written] {args.out}")


# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == "__main__":
    main()