#!/usr/bin/env python3 """Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric. This is a thin CLI wrapper that: 1. Loads a base model + LoRA adapter. 2. Runs the same scoring procedure as ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``, i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate over empty cells in row-major order and scores each predicted value with ``score_prediction_text`` against the i-consistent target set at ``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal). 3. The only difference vs the cell-policy is that the model emits the whole puzzle in ONE forward pass, then the predicted list is split into per-cell singletons. Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``) and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k models (``adaptive_latent_baseline_sudoku_train.py``). """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict, List import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed ROOT = Path(__file__).resolve().parent.parent if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore # noqa: E402 load_jsonl_rows, pick_dtype, ) from _runs.simple_baseline_sudoku_train import ( # type: ignore # noqa: E402 run_eval as run_eval_strawman, ) from _runs.adaptive_latent_baseline_sudoku_train import ( # type: ignore # noqa: E402 run_eval as run_eval_adaptive_k, ) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser() p.add_argument("--kind", choices=["strawman", "adaptive_k"], required=True) p.add_argument("--model_name", default="Qwen/Qwen2.5-1.5B-Instruct") p.add_argument("--adapter_dir", required=True) p.add_argument("--eval_jsonl", required=True) p.add_argument("--cache_dir", default=str(ROOT / ".hf_cache")) p.add_argument("--eval_rows", type=int, default=100) p.add_argument("--max_completion_length", type=int, default=96) p.add_argument("--stage_i", type=int, default=3) p.add_argument( "--num_cot_tokens", type=int, default=0, help="Only used when --kind adaptive_k.", ) p.add_argument("--seed", type=int, default=0) p.add_argument("--out_json", default="") return p.parse_args() def main() -> None: args = parse_args() set_seed(int(args.seed)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = pick_dtype() print(f"[eval-cellpolicy] kind={args.kind} adapter={args.adapter_dir}", flush=True) print(f"[eval-cellpolicy] eval_jsonl={args.eval_jsonl} stage_i={args.stage_i}", flush=True) tokenizer = AutoTokenizer.from_pretrained( args.model_name, cache_dir=args.cache_dir, use_fast=True ) if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: tokenizer.pad_token = tokenizer.eos_token base = AutoModelForCausalLM.from_pretrained( args.model_name, cache_dir=args.cache_dir, torch_dtype=dtype ) model = PeftModel.from_pretrained(base, args.adapter_dir) model.to(device) model.eval() rows: List[Dict[str, Any]] = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows)) print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True) if args.kind == "strawman": metrics = run_eval_strawman( model, tokenizer, rows, device, max_new_tokens=int(args.max_completion_length), print_n=3, stage_i=int(args.stage_i), ) else: metrics = run_eval_adaptive_k( model, tokenizer, rows, device, num_cot_tokens=int(args.num_cot_tokens), max_new_tokens=int(args.max_completion_length), print_n=3, stage_i=int(args.stage_i), ) print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True) if args.out_json: Path(args.out_json).parent.mkdir(parents=True, exist_ok=True) with open(args.out_json, "w") as f: json.dump( { "kind": args.kind, "adapter_dir": args.adapter_dir, "eval_jsonl": args.eval_jsonl, "stage_i": int(args.stage_i), "num_cot_tokens": int(args.num_cot_tokens), "metrics": metrics, }, f, indent=2, ) print(f"[eval-cellpolicy] wrote {args.out_json}", flush=True) if __name__ == "__main__": main()