curriculum-cot-code / _runs /eval_strawman_cellpolicy.py
Avra98's picture
Add data/ JSONLs + _runs/ launch scripts (override .gitignore)
48c96cf verified
#!/usr/bin/env python3
"""Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric.

This is a thin CLI wrapper that:

1. Loads a base model + LoRA adapter.
2. Runs the same scoring procedure as
   ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``,
   i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate
   over empty cells in row-major order and scores each predicted value
   with ``score_prediction_text`` against the i-consistent target set at
   ``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal).
3. The only difference vs. the cell-policy is that the model emits the whole
   puzzle in ONE forward pass, then the predicted list is split into
   per-cell singletons.

Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``)
and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k
models (``adaptive_latent_baseline_sudoku_train.py``).
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
# Directory two levels above this file (the parent of the folder holding this
# script). It is placed at the front of ``sys.path`` so that the project-local
# packages imported right below can be resolved when this file is run directly.
ROOT = Path(__file__).resolve().parent.parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore # noqa: E402
load_jsonl_rows,
pick_dtype,
)
from _runs.simple_baseline_sudoku_train import ( # type: ignore # noqa: E402
run_eval as run_eval_strawman,
)
from _runs.adaptive_latent_baseline_sudoku_train import ( # type: ignore # noqa: E402
run_eval as run_eval_adaptive_k,
)
def parse_args(argv: "List[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the cell-policy re-evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what a bare ``parse_args()`` call
            has always done; an explicit list allows the parser to be driven
            programmatically (e.g. from another script or a test).

    Returns:
        The populated namespace. ``--kind``, ``--adapter_dir`` and
        ``--eval_jsonl`` are required; every other option has a default.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments, and
            after printing ``--help``.
    """
    p = argparse.ArgumentParser(
        description=(
            "Re-evaluate a strawman / adaptive-k checkpoint using the "
            "cell-policy metric."
        ),
    )
    p.add_argument(
        "--kind",
        choices=["strawman", "adaptive_k"],
        required=True,
        help="Which evaluation routine to run on the checkpoint.",
    )
    p.add_argument(
        "--model_name",
        default="Qwen/Qwen2.5-1.5B-Instruct",
        help="Base model name or path the adapter is applied to.",
    )
    p.add_argument(
        "--adapter_dir",
        required=True,
        help="Directory of the LoRA adapter to load on top of the base model.",
    )
    p.add_argument(
        "--eval_jsonl",
        required=True,
        help="JSONL file holding the evaluation rows.",
    )
    p.add_argument(
        "--cache_dir",
        default=str(ROOT / ".hf_cache"),
        help="Cache directory used when loading the tokenizer and base model.",
    )
    p.add_argument(
        "--eval_rows",
        type=int,
        default=100,
        help="Row limit handed to the JSONL loader.",
    )
    p.add_argument(
        "--max_completion_length",
        type=int,
        default=96,
        help="Forwarded to the evaluation routine as max_new_tokens.",
    )
    p.add_argument(
        "--stage_i",
        type=int,
        default=3,
        help="Stage index forwarded to the evaluation routine.",
    )
    p.add_argument(
        "--num_cot_tokens",
        type=int,
        default=0,
        help="Only used when --kind adaptive_k.",
    )
    p.add_argument("--seed", type=int, default=0, help="Random seed.")
    p.add_argument(
        "--out_json",
        default="",
        help="If non-empty, path of a JSON file to write the metrics to.",
    )
    # ``parse_args(None)`` falls back to ``sys.argv[1:]``, so existing callers
    # that pass nothing keep their behaviour.
    return p.parse_args(argv)
def _load_model_and_tokenizer(args: argparse.Namespace, device: torch.device, dtype: Any):
    """Load the tokenizer and the base model wrapped with the LoRA adapter.

    Returns a ``(model, tokenizer)`` pair; the model has been moved to
    ``device`` and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    base = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, torch_dtype=dtype
    )
    model = PeftModel.from_pretrained(base, args.adapter_dir)
    model.to(device)
    model.eval()
    return model, tokenizer


def _run_selected_eval(
    args: argparse.Namespace,
    model: Any,
    tokenizer: Any,
    rows: List[Dict[str, Any]],
    device: torch.device,
) -> Any:
    """Dispatch to the evaluation routine chosen by ``--kind`` and return its metrics."""
    # Keyword arguments common to both evaluation routines.
    # NOTE(review): print_n=3 is passed through unchanged; presumably it limits
    # how many examples the routine prints -- confirm in the run_eval sources.
    shared = {
        "max_new_tokens": int(args.max_completion_length),
        "print_n": 3,
        "stage_i": int(args.stage_i),
    }
    if args.kind == "strawman":
        return run_eval_strawman(model, tokenizer, rows, device, **shared)
    # argparse restricts --kind to two choices, so this is the adaptive_k case.
    return run_eval_adaptive_k(
        model, tokenizer, rows, device,
        num_cot_tokens=int(args.num_cot_tokens),
        **shared,
    )


def _write_report(args: argparse.Namespace, metrics: Any) -> None:
    """Write the run configuration and metrics to ``args.out_json`` as JSON."""
    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    report = {
        "kind": args.kind,
        "adapter_dir": args.adapter_dir,
        "eval_jsonl": args.eval_jsonl,
        "stage_i": int(args.stage_i),
        "num_cot_tokens": int(args.num_cot_tokens),
        "metrics": metrics,
    }
    # Explicit encoding so the output does not depend on the platform locale.
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"[eval-cellpolicy] wrote {args.out_json}", flush=True)


def main() -> None:
    """Entry point: load the checkpoint, evaluate it, and report the metrics.

    Steps, in order: parse the CLI options, seed the RNGs, load tokenizer and
    adapter-wrapped model (CUDA when available, otherwise CPU), load at most
    ``--eval_rows`` rows from ``--eval_jsonl``, run the evaluation routine
    selected by ``--kind``, print the metrics as JSON, and -- when
    ``--out_json`` is non-empty -- also write them to that file.
    """
    args = parse_args()
    set_seed(int(args.seed))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = pick_dtype()
    print(f"[eval-cellpolicy] kind={args.kind} adapter={args.adapter_dir}", flush=True)
    print(f"[eval-cellpolicy] eval_jsonl={args.eval_jsonl} stage_i={args.stage_i}", flush=True)

    model, tokenizer = _load_model_and_tokenizer(args, device, dtype)

    rows: List[Dict[str, Any]] = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))
    print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True)

    metrics = _run_selected_eval(args, model, tokenizer, rows, device)
    print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True)

    if args.out_json:
        _write_report(args, metrics)
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()