#!/usr/bin/env python3
"""Sampled pass@k (self-consistency-style) evaluation for math-conjecture model checkpoints."""
from __future__ import annotations

import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import torch
import yaml
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

SCRIPT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"

BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
SPACE_RE = re.compile(r"\s+")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on a held-out split.")
    parser.add_argument(
        "--config",
        type=Path,
        default=DEFAULT_CONFIG_PATH,
        help="Training config used for prompt formatting defaults.",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Override base model id from config.",
    )
    parser.add_argument(
        "--adapter-path",
        type=Path,
        default=None,
        help="Optional LoRA adapter path to load on top of the base model.",
    )
    parser.add_argument(
        "--eval-file",
        type=Path,
        default=None,
        help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).",
    )
    parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
    parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
    parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
    parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
    parser.add_argument("--seed", type=int, default=17, help="Random seed.")
    parser.add_argument(
        "--progress-every",
        type=int,
        default=25,
        help="Print progress every N evaluated rows (0 disables).",
    )
    parser.add_argument(
        "--sample-records",
        type=int,
        default=30,
        help="How many sample records to store in the report.",
    )
    parser.add_argument(
        "--output-json",
        type=Path,
        default=DEFAULT_OUTPUT_JSON,
        help="Where to write the evaluation report.",
    )
    return parser.parse_args()


def as_text(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    return str(value).strip()


def as_float(value: Any, default: float) -> float:
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def as_int(value: Any, default: int) -> int:
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def load_config(path: Path) -> Dict[str, Any]:
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError("Invalid YAML config.")
    return cfg


def normalize_answer(text: str) -> str:
    text = text.strip().lower()
    text = text.replace("$", "")
    text = text.replace("\\left", "").replace("\\right", "")
    text = text.replace("\\,", "").replace("\\!", "").replace("\\;", "")
    text = SPACE_RE.sub(" ", text)
    return text.strip(" .")


def extract_boxed_values(text: str) -> List[str]:
    return [normalize_answer(match) for match in BOXED_RE.findall(text or "") if normalize_answer(match)]
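# parse_numeric_value accepts plain integers, decimals, scientific notation
# (e.g. "3.1e-2"), and simple "a/b" fractions with integer parts; anything
# else is treated as non-numeric and compared as text.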
def parse_numeric_value(text: str) -> Optional[float]:
    normalized = normalize_answer(text)
    if not normalized:
        return None
    candidate = normalized.replace(",", "")
    if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
        left, right = candidate.split("/", maxsplit=1)
        try:
            numerator = float(left.strip())
            denominator = float(right.strip())
        except ValueError:
            return None
        if denominator == 0:
            return None
        return numerator / denominator
    if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
        try:
            return float(candidate)
        except ValueError:
            return None
    return None


def approximately_equal(left: float, right: float) -> bool:
    tolerance = 1e-6 * max(1.0, abs(left), abs(right))
    return abs(left - right) <= tolerance
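# Matching cascade, strongest signal first: exact normalized equality, then a
# \boxed{...} overlap in either direction, then numeric equivalence within
# tolerance, then a plain substring hit. The first tier that fires determines
# the "reason" recorded in the match details.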
def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
    cand_norm = normalize_answer(candidate)
    if not cand_norm:
        return {
            "match": False,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "empty_candidate",
        }
    cand_boxed = extract_boxed_values(candidate)
    cand_num = parse_numeric_value(cand_norm)
    substring_hit = False
    boxed_hit = False
    numeric_hit = False
    for expected in expected_values:
        exp_norm = normalize_answer(expected)
        if not exp_norm:
            continue
        if cand_norm == exp_norm:
            return {
                "match": True,
                "exact": True,
                "boxed": exp_norm in cand_boxed,
                "numeric": False,
                "reason": "exact",
            }
        if exp_norm in cand_norm or cand_norm in exp_norm:
            substring_hit = True
        expected_boxed = extract_boxed_values(expected)
        for cand_box in cand_boxed:
            if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
                boxed_hit = True
        for exp_box in expected_boxed:
            if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
                boxed_hit = True
        exp_num = parse_numeric_value(exp_norm)
        if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
            numeric_hit = True
    if boxed_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": True,
            "numeric": numeric_hit,
            "reason": "boxed",
        }
    if numeric_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": True,
            "reason": "numeric",
        }
    if substring_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "substring",
        }
    return {
        "match": False,
        "exact": False,
        "boxed": False,
        "numeric": False,
        "reason": "no_match",
    }


def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
    out: List[str] = []
    final_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
    target_field = as_text(data_cfg.get("target_field")) or "target"
    final_answer = row.get(final_field)
    if final_answer is not None:
        txt = as_text(final_answer)
        if txt:
            out.append(txt)
    target = row.get(target_field)
    if target is None:
        return out
    if isinstance(target, str):
        stripped = target.strip()
        if not stripped:
            return out
        try:
            target = json.loads(stripped)
        except json.JSONDecodeError:
            # Not JSON: treat the raw string itself as the expected value.
            out.append(stripped)
            return out
    if isinstance(target, dict):
        for value in target.values():
            if isinstance(value, list):
                for item in value:
                    txt = as_text(item)
                    if txt:
                        out.append(txt)
            else:
                txt = as_text(value)
                if txt:
                    out.append(txt)
    elif isinstance(target, list):
        for item in target:
            txt = as_text(item)
            if txt:
                out.append(txt)
    else:
        txt = as_text(target)
        if txt:
            out.append(txt)
    return out


def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field))
    if not prompt:
        prompt = "Solve the math task."
    meta_fields = [
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    ]
    lines = []
    for key, label in meta_fields:
        value = as_text(row.get(key))
        if value:
            lines.append(f"{label}: {value}")
    if lines:
        return f"{prompt}\n\nMetadata:\n" + "\n".join(lines)
    return prompt


def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    system_prompt = as_text(data_cfg.get("system_prompt"))
    if not system_prompt:
        system_prompt = "You are a rigorous mathematical reasoning assistant."
    user_block = build_user_block(row, data_cfg)
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_block},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"


def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
    # Fallback for decodings that still include the prompt; when the prompt is
    # sliced off at the token level before decoding, this reduces to a strip.
    if full_generation.startswith(prompt_text):
        return full_generation[len(prompt_text):].strip()
    return full_generation.strip()


def load_model_and_tokenizer(
    base_model: str,
    adapter_path: Optional[Path],
    trust_remote_code: bool,
) -> Tuple[Any, AutoTokenizer]:
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
    added_tokens = 0
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    if tokenizer.pad_token is None:
        added_tokens = tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=trust_remote_code,
    )
    if added_tokens:
        # A brand-new pad token is out of the model's vocabulary until the
        # embedding matrix is resized to match the tokenizer.
        model.resize_token_embeddings(len(tokenizer))
    if adapter_path is not None:
        model = PeftModel.from_pretrained(model, str(adapter_path))
    model.eval()
    return model, tokenizer


def make_bucket() -> Dict[str, Any]:
    return {
        "evaluated_rows": 0,
        "pass_at_1_hits": 0,
        "pass_at_k_hits": 0,
        "exact_at_1_hits": 0,
        "exact_at_k_hits": 0,
        "boxed_at_k_hits": 0,
    }


def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
    bucket["evaluated_rows"] += 1
    if hit1:
        bucket["pass_at_1_hits"] += 1
    if hitk:
        bucket["pass_at_k_hits"] += 1
    if exact1:
        bucket["exact_at_1_hits"] += 1
    if exactk:
        bucket["exact_at_k_hits"] += 1
    if boxedk:
        bucket["boxed_at_k_hits"] += 1


def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
    rows = int(bucket.get("evaluated_rows", 0))
    total = max(rows, 1)
    return {
        "evaluated_rows": rows,
        "pass_at_1": float(bucket.get("pass_at_1_hits", 0)) / total,
        "pass_at_k": float(bucket.get("pass_at_k_hits", 0)) / total,
        "exact_at_1": float(bucket.get("exact_at_1_hits", 0)) / total,
        "exact_at_k": float(bucket.get("exact_at_k_hits", 0)) / total,
        "boxed_at_k": float(bucket.get("boxed_at_k_hits", 0)) / total,
    }


def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
    if arg_eval_file is not None:
        return arg_eval_file
    post_eval_cfg = cfg.get("post_eval", {})
    data_cfg = cfg.get("data", {})
    for candidate in (
        as_text(post_eval_cfg.get("eval_file")),
        as_text(data_cfg.get("default_validation_file")),
        "data/releases/v1/test.parquet",
        "workspace/data/releases/v1/test.parquet",
    ):
        if not candidate:
            continue
        path = Path(candidate)
        if path.exists():
            return path
    return Path("data/releases/v1/test.parquet")
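# run_evaluation reports per-row metrics aggregated overall and per family /
# difficulty bucket:
#   pass@1  - first sampled generation matches any expected value
#   pass@k  - any of the k samples matches
#   exact@1 / exact@k - same, restricted to exact normalized equality
#   boxed@k - any sample matches via a \boxed{...} value
# composite_score = 0.30 * pass@1 + 0.50 * pass@k + 0.20 * exact@k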
def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
    if args.k < 1:
        raise ValueError("--k must be >= 1.")
    if args.max_samples < 1:
        raise ValueError("--max-samples must be >= 1.")
    if args.max_new_tokens < 1:
        raise ValueError("--max-new-tokens must be >= 1.")
    if args.max_input_length < 128:
        raise ValueError("--max-input-length must be >= 128.")
    if args.temperature <= 0:
        raise ValueError("--temperature must be > 0.")
    if not 0 < args.top_p <= 1:
        raise ValueError("--top-p must be in (0, 1].")

    cfg = load_config(args.config)
    data_cfg = cfg.get("data", {})
    model_cfg = cfg.get("model", {})
    set_seed(args.seed)

    base_model = args.base_model or as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("Base model is required via --base-model or config.model.base_model.")
    if args.adapter_path is not None and not args.adapter_path.exists():
        raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
    eval_file = resolve_eval_file(args.eval_file, cfg)
    if not eval_file.exists():
        raise FileNotFoundError(f"Evaluation file not found: {eval_file}")

    model, tokenizer = load_model_and_tokenizer(
        base_model=base_model,
        adapter_path=args.adapter_path,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
    )

    ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
    if 0 < args.max_samples < len(ds):
        ds = ds.select(range(args.max_samples))

    totals = make_bucket()
    family_buckets: Dict[str, Dict[str, Any]] = {}
    difficulty_buckets: Dict[str, Dict[str, Any]] = {}
    processed_rows = 0
    skipped_no_expected = 0
    samples: List[Dict[str, Any]] = []
    model_device = next(model.parameters()).device
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"

    for row in ds:
        expected_values = flatten_expected(row, data_cfg)
        if not expected_values:
            skipped_no_expected += 1
            continue
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_input_length,
        )
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
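        # One batched generate call draws all k samples for this prompt via
        # num_return_sequences, so the prompt is encoded only once.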
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                num_return_sequences=args.k,
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode only the continuation tokens; decoding the full sequence and
        # string-stripping the prompt is unreliable once chat templates and
        # special-token skipping are involved.
        prompt_length = inputs["input_ids"].shape[1]
        generations = tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)
        candidates = [extract_candidate_text(text, prompt_text) for text in generations]
        details = [match_candidate(candidate, expected_values) for candidate in candidates]
        matches = [bool(item["match"]) for item in details]
        exacts = [bool(item["exact"]) for item in details]
        boxed = [bool(item["boxed"]) for item in details]
        hit1 = bool(matches and matches[0])
        hitk = any(matches)
        exact1 = bool(exacts and exacts[0])
        exactk = any(exacts)
        boxedk = any(boxed)

        update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
        family = as_text(row.get("family")) or "__unknown__"
        if family not in family_buckets:
            family_buckets[family] = make_bucket()
        update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
        difficulty = as_text(row.get("difficulty")) or "__unknown__"
        if difficulty not in difficulty_buckets:
            difficulty_buckets[difficulty] = make_bucket()
        update_bucket(
            difficulty_buckets[difficulty],
            hit1=hit1,
            hitk=hitk,
            exact1=exact1,
            exactk=exactk,
            boxedk=boxedk,
        )
        processed_rows += 1
        if args.progress_every > 0 and processed_rows % args.progress_every == 0:
            print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
        if len(samples) < args.sample_records:
            samples.append(
                {
                    "uid": as_text(row.get("uid")),
                    "family": family,
                    "difficulty": difficulty,
                    "prompt": as_text(row.get(prompt_field)),
                    "expected_values": expected_values[:5],
                    "candidates": candidates,
                    "match_details": details,
                    "matches": matches,
                }
            )

    total_eval = int(totals.get("evaluated_rows", 0))
    denominator = max(total_eval, 1)
    pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
    pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
    exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
    exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
    boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
    composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k

    report: Dict[str, Any] = {
        "base_model": base_model,
        "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
        "eval_file": str(eval_file),
        "config": str(args.config),
        "evaluated_rows": total_eval,
        "skipped_rows_without_targets": skipped_no_expected,
        "requested_rows": len(ds),
        "k": args.k,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_1": exact_at_1,
        "exact_at_k": exact_at_k,
        "boxed_at_k": boxed_at_k,
        "composite_score": composite_score,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
        "max_input_length": args.max_input_length,
        "seed": args.seed,
        "family_metrics": {key: finalize_bucket(family_buckets[key]) for key in sorted(family_buckets)},
        "difficulty_metrics": {key: finalize_bucket(difficulty_buckets[key]) for key in sorted(difficulty_buckets)},
        "samples": samples,
    }

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")

    summary_view = {
        "evaluated_rows": total_eval,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_k": exact_at_k,
        "composite_score": composite_score,
        "k": args.k,
    }
    print(json.dumps(summary_view, indent=2))
    print(f"Saved report to {args.output_json}")
    return report


def main() -> None:
    args = parse_args()
    run_evaluation(args)


if __name__ == "__main__":
    main()
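# Example invocation (the script path below is illustrative; all flags are
# defined in parse_args above):
#   python evaluate_pass_at_k.py \
#       --config configs/deepseek_math_sota.yaml \
#       --adapter-path runs/adapter-latest \
#       --k 4 --max-samples 300 --output-json runs/latest_eval_report.json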