File size: 5,053 Bytes
aedd6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Stage 1 baseline eval: CRUXEval-O output prediction via full-trace generation.

Feed the training prompt (seeds frame 0), let the SFT model generate the trace,
take main()'s last return value as the predicted output, score by execution.
Greedy => pass@1 is the exact-match fraction. Reuses cwm_andre eval logic.
"""

import argparse
import json
import os
import subprocess
import sys
from datetime import timedelta

import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

from data.dataset import _prompt_str
from data.sources import load_cruxeval
from tokens import add_trace_tokens, token_ids

ARG_SEP, FRAME_SEP, RETURN_SEP = "<|arg_sep|>", "<|frame_sep|>", "<|return_sep|>"


def extract_answer_trace_full(gen: str) -> str | None:
    """Value of main()'s last RETURN frame: ...<|arg_sep|>"value"<|frame_sep|>."""
    r = gen.rfind(RETURN_SEP)
    if r == -1:
        return None
    a = gen.find(ARG_SEP, r)
    if a == -1:
        return None
    rest = gen[a + len(ARG_SEP):]
    end = rest.find(FRAME_SEP)
    val = (rest[:end] if end != -1 else rest).strip()
    if not val:
        return None
    try:
        return json.loads(val)
    except json.JSONDecodeError:
        return val


def check_correct(code: str, expected: str, predicted: str, timeout: float = 3.0) -> bool:
    """Report whether ``expected == predicted`` holds after running ``code``.

    Follows the CRUXEval check: the program source is followed by a single
    ``assert {expected} == {predicted}`` line and run in a fresh interpreter
    (``sys.executable -c``). A zero exit status means correct; a failed
    assertion, any other error in the program, a timeout, or a failure to
    launch the subprocess all mean incorrect.
    """
    program = f"{code}\nassert {expected} == {predicted}"
    argv = [sys.executable, "-c", program]
    try:
        completed = subprocess.run(argv, timeout=timeout, capture_output=True)
    except Exception:  # deliberate best-effort: timeouts/launch errors score as wrong
        return False
    return completed.returncode == 0


def _decode_generation(tok, gen_ids: torch.Tensor, eot_id: int) -> str:
    """Decode one row of generated ids, cut at the first ``eot_id``.

    ``model.generate`` is called with ``pad_token_id=eot_id``, so rows that
    finish early are right-filled with EOT ids. Everything from the first
    generated EOT on is the terminator or padding, and is dropped before
    decoding. This stops it from leaking into an unterminated final return
    value or into the saved generation.
    """
    stop = (gen_ids == eot_id).nonzero()
    if stop.numel():
        gen_ids = gen_ids[: int(stop[0, 0])]
    return tok.decode(gen_ids, skip_special_tokens=False)


def _merge_rank_results(parts: list, n: int) -> list:
    """Re-interleave per-rank result lists into dataset order.

    Rank ``r`` evaluated ``rows[r::world]``, so dataset row ``i`` is element
    ``i // world`` of rank ``i % world``'s list. ``parts`` must hold one list
    per rank, with lengths summing to ``n``.
    """
    world = len(parts)
    return [parts[i % world][i // world] for i in range(n)]


def main():
    """Greedy CRUXEval-O evaluation of a trace-generating model.

    Under torchrun (``RANK`` set in the environment), each rank evaluates a
    disjoint round-robin shard of the dataset. Correct/valid-format counts are
    all-reduced and per-row results are gathered on rank 0. Rank 0 prints
    pass@1 and the valid-format rate. If ``--out`` is given, it also writes
    both metrics and the per-row results (in dataset order) as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--n_samples", type=int, default=-1)  # <= 0 means all rows
    ap.add_argument("--max_new_tokens", type=int, default=8192)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--out", default="")  # optional path for the JSON report
    args = ap.parse_args()

    # DDP-style data parallelism for inference: torchrun sets RANK/WORLD_SIZE/LOCAL_RANK.
    ddp = "RANK" in os.environ
    rank = int(os.environ.get("RANK", 0))
    world = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if ddp:
        # Generous timeout: ranks finish at different times under long
        # generations before meeting in the collectives below.
        dist.init_process_group("nccl", timeout=timedelta(hours=1))
    torch.cuda.set_device(local_rank)

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent; ensures trace tokens present
    if tok.pad_token is None:
        # Batched tokenization with padding=True raises without a pad token;
        # fall back to EOT, which generate() below also uses for padding.
        tok.pad_token = "<|end_of_text|>"
    tok.padding_side = "left"  # left-pad so all generated tokens start at the same offset
    eot_id = token_ids(tok)["<|end_of_text|>"]
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16).to(local_rank).eval()

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]  # disjoint round-robin split across ranks

    n_correct = n_fmt = 0
    results = []
    for bi, batch_start in enumerate(range(0, len(shard), args.batch_size)):
        batch = shard[batch_start: batch_start + args.batch_size]
        enc = tok([_prompt_str(r["code"], r["input"]) for r in batch],
                  return_tensors="pt", padding=True, add_special_tokens=False).to(local_rank)
        prompt_len = enc["input_ids"].shape[1]  # left padding => continuations start here
        with torch.no_grad():
            out = model.generate(**enc, max_new_tokens=args.max_new_tokens, do_sample=False,
                                 eos_token_id=eot_id, pad_token_id=eot_id)
        for j, r in enumerate(batch):
            gen = _decode_generation(tok, out[j, prompt_len:], eot_id)
            pred = extract_answer_trace_full(gen)
            ok = pred is not None and check_correct(r["code"], r["output"], pred)
            n_fmt += pred is not None
            n_correct += ok
            results.append({"id": r["id"], "expected": r["output"], "predicted": pred,
                            "correct": ok, "generation": gen})
        if rank == 0 and (bi + 1) % 5 == 0:
            done = batch_start + len(batch)
            print(f"  rank0 {done}/{len(shard)}  pass@1={n_correct/done:.4f}", flush=True)

    # Reduce metrics and gather per-row results across ranks.
    if ddp:
        t = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(t)  # default reduction op is SUM
        n_correct, n_fmt = int(t[0]), int(t[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            # Plain concatenation would order rows by rank (and so depend on
            # world size); restore the dataset order instead.
            results = _merge_rank_results(gathered, n)

    if rank == 0:
        denom = max(n, 1)  # empty dataset reports 0.0 instead of dividing by zero
        pass_at_1, valid_format = n_correct / denom, n_fmt / denom
        print(f"\nCRUXEval-O pass@1={pass_at_1:.4f}  "
              f"valid_format={valid_format:.4f}  (n={n}, greedy)")
        if args.out:
            with open(args.out, "w") as f:
                json.dump({"pass_at_1": pass_at_1, "valid_format": valid_format,
                           "n": n, "results": results}, f, indent=2)
    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()