File size: 4,840 Bytes
48c96cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric.

This is a thin CLI wrapper that:

  1. Loads a base model + LoRA adapter.
  2. Runs the same scoring procedure as
     ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``,
     i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate
     over empty cells in row-major order and scores each predicted value
     with ``score_prediction_text`` against the i-consistent target set at
     ``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal).
  3. Differs from the cell-policy eval in one respect only: the model emits
     the whole puzzle in ONE forward pass, and the predicted list is then
     split into per-cell singletons.

Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``)
and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k
models (``adaptive_latent_baseline_sudoku_train.py``).
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

# ROOT is the parent of the directory containing this file (presumably the
# repository root -- confirm against the checkout layout). It is pushed to
# the front of sys.path so the project-local imports that follow resolve
# when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from multi_output_cell_policy.sft_multi_output_train import (  # type: ignore  # noqa: E402
    load_jsonl_rows,
    pick_dtype,
)
from _runs.simple_baseline_sudoku_train import (  # type: ignore  # noqa: E402
    run_eval as run_eval_strawman,
)
from _runs.adaptive_latent_baseline_sudoku_train import (  # type: ignore  # noqa: E402
    run_eval as run_eval_adaptive_k,
)


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the cell-policy re-evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` callers behave
            exactly as before; passing an explicit list allows programmatic
            use and unit testing.

    Returns:
        Namespace with the attributes ``kind``, ``model_name``,
        ``adapter_dir``, ``eval_jsonl``, ``cache_dir``, ``eval_rows``,
        ``max_completion_length``, ``stage_i``, ``num_cot_tokens``, ``seed``
        and ``out_json``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments, and
            after printing ``--help``.
    """
    p = argparse.ArgumentParser(
        description=(
            "Re-evaluate a strawman / adaptive-k LoRA checkpoint using the "
            "cell-policy metric."
        ),
        # Show each option's default in --help.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument(
        "--kind",
        choices=["strawman", "adaptive_k"],
        required=True,
        help="Checkpoint family to evaluate; selects the eval routine.",
    )
    p.add_argument(
        "--model_name",
        default="Qwen/Qwen2.5-1.5B-Instruct",
        help="Base model name or path passed to from_pretrained.",
    )
    p.add_argument(
        "--adapter_dir",
        required=True,
        help="Directory of the LoRA adapter loaded on top of the base model.",
    )
    p.add_argument(
        "--eval_jsonl",
        required=True,
        help="JSONL file with the evaluation rows.",
    )
    p.add_argument(
        "--cache_dir",
        default=str(ROOT / ".hf_cache"),
        help="Cache directory passed to from_pretrained.",
    )
    p.add_argument(
        "--eval_rows",
        type=int,
        default=100,
        help="Row limit passed to the JSONL loader.",
    )
    p.add_argument(
        "--max_completion_length",
        type=int,
        default=96,
        help="Passed to the eval routine as max_new_tokens.",
    )
    p.add_argument(
        "--stage_i",
        type=int,
        default=3,
        help="Stage i of the i-consistent target set used for scoring.",
    )
    p.add_argument(
        "--num_cot_tokens",
        type=int,
        default=0,
        help="Only used when --kind adaptive_k.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed passed to transformers.set_seed.",
    )
    p.add_argument(
        "--out_json",
        default="",
        help="Optional path of a JSON report; nothing is written when empty.",
    )
    return p.parse_args(argv)


def main() -> None:
    """Load base model + LoRA adapter, run the selected eval, report metrics.

    Steps, in order: seed via ``set_seed``, choose device and dtype, load the
    tokenizer (reusing EOS as the pad token when no pad token is defined),
    load the base model and wrap it with the adapter from ``--adapter_dir``,
    load the eval rows, dispatch to the strawman or adaptive-k ``run_eval``,
    print the metrics, and, when ``--out_json`` is non-empty, write a JSON
    report (creating its parent directory if needed).
    """
    opts = parse_args()
    set_seed(int(opts.seed))
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    dtype = pick_dtype()

    print(f"[eval-cellpolicy] kind={opts.kind}  adapter={opts.adapter_dir}", flush=True)
    print(f"[eval-cellpolicy] eval_jsonl={opts.eval_jsonl}  stage_i={opts.stage_i}", flush=True)

    tokenizer = AutoTokenizer.from_pretrained(
        opts.model_name,
        cache_dir=opts.cache_dir,
        use_fast=True,
    )
    # Fall back to EOS as the pad token when the tokenizer defines none.
    lacks_pad = tokenizer.pad_token_id is None
    if lacks_pad and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model = PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained(
            opts.model_name,
            cache_dir=opts.cache_dir,
            torch_dtype=dtype,
        ),
        opts.adapter_dir,
    )
    model.to(device)
    model.eval()

    row_limit = int(opts.eval_rows)
    rows: List[Dict[str, Any]] = load_jsonl_rows(opts.eval_jsonl, limit_rows=row_limit)
    print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True)

    # Keyword arguments shared by both eval routines.
    shared_kwargs = {
        "max_new_tokens": int(opts.max_completion_length),
        "print_n": 3,
        "stage_i": int(opts.stage_i),
    }
    if opts.kind != "strawman":
        # Only the adaptive-k routine takes num_cot_tokens.
        metrics = run_eval_adaptive_k(
            model,
            tokenizer,
            rows,
            device,
            num_cot_tokens=int(opts.num_cot_tokens),
            **shared_kwargs,
        )
    else:
        metrics = run_eval_strawman(model, tokenizer, rows, device, **shared_kwargs)

    print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True)

    if not opts.out_json:
        return

    report = {
        "kind": opts.kind,
        "adapter_dir": opts.adapter_dir,
        "eval_jsonl": opts.eval_jsonl,
        "stage_i": int(opts.stage_i),
        "num_cot_tokens": int(opts.num_cot_tokens),
        "metrics": metrics,
    }
    Path(opts.out_json).parent.mkdir(parents=True, exist_ok=True)
    with open(opts.out_json, "w") as fh:
        json.dump(report, fh, indent=2)
    print(f"[eval-cellpolicy] wrote {opts.out_json}", flush=True)


# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()