"""Re-evaluate any strawman / adaptive-k checkpoint using the cell-policy metric.

This is a thin CLI wrapper that:

1. Loads a base model + LoRA adapter.
2. Dispatches to the ``run_eval`` of the matching baseline training script,
   which applies the same scoring procedure as
   ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``,
   i.e. for each puzzle it uses ``build_cell_examples_from_row`` to iterate
   over empty cells in row-major order and scores each predicted value
   with ``score_prediction_text`` against the i-consistent target set at
   ``--stage_i`` (default 3, matching the S3 eval reported in the rebuttal).
3. The only difference from the cell-policy eval is that the model generates
   the whole puzzle in a single completion; the predicted list is then split
   into per-cell singletons.

Use ``--kind strawman`` for vanilla LoRA models (``simple_baseline_sudoku_train.py``)
and ``--kind adaptive_k --num_cot_tokens K`` for recurrent-hidden adaptive-k
models (``adaptive_latent_baseline_sudoku_train.py``).
"""
|
|
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
|
|
# Repository root: the directory two levels above this script. It is pushed
# onto the front of ``sys.path`` (only if not already present) so that the
# project-local imports which follow resolve regardless of the working
# directory the script is launched from.
ROOT = Path(__file__).resolve().parent.parent
_root_entry = str(ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry
|
|
| from multi_output_cell_policy.sft_multi_output_train import ( |
| load_jsonl_rows, |
| pick_dtype, |
| ) |
| from _runs.simple_baseline_sudoku_train import ( |
| run_eval as run_eval_strawman, |
| ) |
| from _runs.adaptive_latent_baseline_sudoku_train import ( |
| run_eval as run_eval_adaptive_k, |
| ) |
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the cell-policy re-evaluation.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before; passing an
            explicit list allows programmatic use and testing.

    Returns:
        The parsed options. ``--kind``, ``--adapter_dir`` and ``--eval_jsonl``
        are required; every other option has a default.
    """
    p = argparse.ArgumentParser(
        description=(
            "Re-evaluate a strawman / adaptive-k LoRA checkpoint with the "
            "cell-policy metric."
        ),
    )
    p.add_argument(
        "--kind",
        choices=["strawman", "adaptive_k"],
        required=True,
        help="Which baseline's run_eval is used to score the checkpoint.",
    )
    p.add_argument(
        "--model_name",
        default="Qwen/Qwen2.5-1.5B-Instruct",
        help="Base model the LoRA adapter is loaded on top of.",
    )
    p.add_argument(
        "--adapter_dir",
        required=True,
        help="Directory containing the LoRA adapter.",
    )
    p.add_argument(
        "--eval_jsonl",
        required=True,
        help="Eval JSONL file.",
    )
    p.add_argument(
        "--cache_dir",
        default=str(ROOT / ".hf_cache"),
        help="Hugging Face cache directory for the tokenizer and base model.",
    )
    p.add_argument(
        "--eval_rows",
        type=int,
        default=100,
        help="Row limit passed to load_jsonl_rows as limit_rows.",
    )
    p.add_argument(
        "--max_completion_length",
        type=int,
        default=96,
        help="Passed to run_eval as max_new_tokens.",
    )
    p.add_argument(
        "--stage_i",
        type=int,
        default=3,
        help="Stage i of the i-consistent target set used for scoring.",
    )
    p.add_argument(
        "--num_cot_tokens",
        type=int,
        default=0,
        help=(
            "K for adaptive-k checkpoints (should match training). "
            "Only used when --kind adaptive_k."
        ),
    )
    p.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed passed to transformers.set_seed.",
    )
    p.add_argument(
        "--out_json",
        default="",
        help="If non-empty, also write the metrics to this JSON file.",
    )
    return p.parse_args(argv)
|
|
|
|
def _load_model_and_tokenizer(
    args: argparse.Namespace, device: torch.device, dtype: Any
) -> Tuple[Any, Any]:
    """Load the tokenizer and the LoRA-adapted model described by ``args``.

    The base model ``args.model_name`` is loaded with ``dtype``, wrapped with
    the adapter in ``args.adapter_dir``, moved to ``device`` and switched to
    eval mode. If the tokenizer has no pad token but does have an EOS token,
    the EOS token is reused for padding.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    base = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, torch_dtype=dtype
    )
    model = PeftModel.from_pretrained(base, args.adapter_dir)
    model.to(device)
    model.eval()
    return model, tokenizer


def _run_selected_eval(
    args: argparse.Namespace,
    model: Any,
    tokenizer: Any,
    rows: List[Dict[str, Any]],
    device: torch.device,
) -> Any:
    """Score ``rows`` with the ``run_eval`` of the baseline selected by ``--kind``.

    Both baselines receive the same generation/scoring settings; only the
    adaptive-k variant additionally takes ``num_cot_tokens``.

    Returns:
        Whatever the selected ``run_eval`` returns (the metrics object that
        ``main`` prints and serializes to JSON).
    """
    common: Dict[str, Any] = {
        "max_new_tokens": args.max_completion_length,
        "print_n": 3,
        "stage_i": args.stage_i,
    }
    if args.kind == "strawman":
        return run_eval_strawman(model, tokenizer, rows, device, **common)
    return run_eval_adaptive_k(
        model,
        tokenizer,
        rows,
        device,
        num_cot_tokens=args.num_cot_tokens,
        **common,
    )


def _write_results(
    args: argparse.Namespace, rows: List[Dict[str, Any]], metrics: Any
) -> None:
    """Write ``metrics`` plus the run configuration to ``args.out_json``.

    Parent directories are created as needed. Besides the original keys, the
    base model, seed, generation length and row counts are recorded so the
    file alone identifies the run that produced it.
    """
    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    record = {
        "kind": args.kind,
        "adapter_dir": args.adapter_dir,
        "eval_jsonl": args.eval_jsonl,
        "stage_i": args.stage_i,
        "num_cot_tokens": args.num_cot_tokens,
        "model_name": args.model_name,
        "eval_rows_requested": args.eval_rows,
        "eval_rows_loaded": len(rows),
        "max_completion_length": args.max_completion_length,
        "seed": args.seed,
        "metrics": metrics,
    }
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(record, f, indent=2)
    print(f"[eval-cellpolicy] wrote {args.out_json}", flush=True)


def main() -> None:
    """CLI entry point: load the checkpoint, score it and report the metrics.

    Steps:
      1. seed via ``transformers.set_seed``;
      2. load the tokenizer, base model and LoRA adapter;
      3. load eval rows from ``--eval_jsonl`` (``--eval_rows`` is passed as
         ``limit_rows``);
      4. run the ``run_eval`` matching ``--kind``;
      5. print the metrics as JSON and, if ``--out_json`` is set, write them
         together with the run configuration.
    """
    args = parse_args()
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = pick_dtype()

    print(f"[eval-cellpolicy] kind={args.kind} adapter={args.adapter_dir}", flush=True)
    print(f"[eval-cellpolicy] eval_jsonl={args.eval_jsonl} stage_i={args.stage_i}", flush=True)
    # --num_cot_tokens defaults to 0; for adaptive-k checkpoints that is almost
    # certainly not the K they were trained with, so surface it loudly.
    if args.kind == "adaptive_k" and args.num_cot_tokens <= 0:
        print(
            "[eval-cellpolicy] WARNING: --kind adaptive_k with "
            f"--num_cot_tokens={args.num_cot_tokens}; pass the K the "
            "checkpoint was trained with.",
            flush=True,
        )

    model, tokenizer = _load_model_and_tokenizer(args, device, dtype)

    rows: List[Dict[str, Any]] = load_jsonl_rows(
        args.eval_jsonl, limit_rows=args.eval_rows
    )
    print(f"[eval-cellpolicy] loaded {len(rows)} eval rows", flush=True)

    metrics = _run_selected_eval(args, model, tokenizer, rows, device)

    print("[eval-cellpolicy] metrics:", json.dumps(metrics, indent=2), flush=True)
    if args.out_json:
        _write_results(args, rows, metrics)
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|