#!/usr/bin/env python3
# Upgrade SOTA curriculum: 4-stage training, post-eval metrics, and quality-gated promotion. (commit a86edac, verified)
| """Self-consistency evaluation for math-conjecture model checkpoints.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence, Tuple | |
| import torch | |
| import yaml | |
| from datasets import load_dataset | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed | |
# Repository root: one directory above the folder containing this script.
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"
# Captures the contents of \boxed{...}; does not handle nested braces.
BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
# Any run of whitespace, used to collapse spacing during normalization.
SPACE_RE = re.compile(r"\s+")
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for an evaluation run."""
    parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
    parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG_PATH,
                        help="Training config used for prompt formatting defaults.")
    parser.add_argument("--base-model", type=str, default=None,
                        help="Override base model id from config.")
    parser.add_argument("--adapter-path", type=Path, default=None,
                        help="Optional LoRA adapter path to load on top of base model.")
    parser.add_argument("--eval-file", type=Path, default=None,
                        help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).")
    parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
    parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
    parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
    parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
    parser.add_argument("--seed", type=int, default=17, help="Random seed.")
    parser.add_argument("--progress-every", type=int, default=25,
                        help="Print progress every N evaluated rows (0 disables).")
    parser.add_argument("--sample-records", type=int, default=30,
                        help="How many sample records to store in report.")
    parser.add_argument("--output-json", type=Path, default=DEFAULT_OUTPUT_JSON,
                        help="Where to write evaluation report.")
    return parser.parse_args()
def as_text(value: Any) -> str:
    """Coerce *value* to a stripped string; None maps to the empty string."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
def as_float(value: Any, default: float) -> float:
    """Convert *value* to float, returning *default* for None or bad input."""
    if value is None:
        return default
    try:
        converted = float(value)
    except (TypeError, ValueError):
        return default
    return converted
def as_int(value: Any, default: int) -> int:
    """Convert *value* to int, returning *default* for None or bad input."""
    if value is None:
        return default
    try:
        converted = int(value)
    except (TypeError, ValueError):
        return default
    return converted
def load_config(path: Path) -> Dict[str, Any]:
    """Read a YAML file and return it as a dict; raise ValueError otherwise."""
    raw = path.read_text(encoding="utf-8")
    cfg = yaml.safe_load(raw)
    if isinstance(cfg, dict):
        return cfg
    raise ValueError("Invalid YAML config.")
def normalize_answer(text: str) -> str:
    """Canonicalize an answer string for comparison.

    Lowercases, strips dollar signs and common LaTeX spacing/delimiter
    commands, collapses internal whitespace, and trims trailing periods.
    """
    cleaned = text.strip().lower().replace("$", "")
    for token in ("\\left", "\\right", "\\,", "\\!", "\\;"):
        cleaned = cleaned.replace(token, "")
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned.strip(" .")
def extract_boxed_values(text: str) -> List[str]:
    """Return the normalized content of each \\boxed{...} span, skipping empties."""
    results: List[str] = []
    for raw in BOXED_RE.findall(text or ""):
        norm = normalize_answer(raw)
        if norm:
            results.append(norm)
    return results
def parse_numeric_value(text: str) -> Optional[float]:
    """Interpret *text* as a number.

    Supports plain integers/decimals (with optional exponent) and simple
    fractions like "3/4". Commas are stripped first. Returns None when the
    text is not numeric or the fraction denominator is zero.
    """
    normalized = normalize_answer(text)
    if not normalized:
        return None
    candidate = normalized.replace(",", "")
    if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
        num_text, _, den_text = candidate.partition("/")
        try:
            numerator = float(num_text.strip())
            denominator = float(den_text.strip())
        except ValueError:
            return None
        if denominator == 0:
            return None
        return numerator / denominator
    if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
        try:
            return float(candidate)
        except ValueError:
            return None
    return None
def approximately_equal(left: float, right: float) -> bool:
    """Relative float comparison at 1e-6, with an absolute floor of 1e-6."""
    scale = max(1.0, abs(left), abs(right))
    return abs(left - right) <= 1e-6 * scale
def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
    """Grade one sampled answer against every accepted answer string.

    Rules are applied in priority order: exact normalized equality
    (returns immediately), boxed-value overlap, numeric equality within
    tolerance, then loose substring containment. The returned dict
    carries a "match" verdict, per-rule boolean flags, and a "reason"
    naming the highest-priority rule that fired.
    """
    cand_norm = normalize_answer(candidate)
    if not cand_norm:
        # The generation normalized to nothing; it cannot match anything.
        return {
            "match": False,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "empty_candidate",
        }
    cand_boxed = extract_boxed_values(candidate)
    cand_num = parse_numeric_value(cand_norm)
    # Flags accumulate across all expected values; only exact short-circuits.
    substring_hit = False
    boxed_hit = False
    numeric_hit = False
    for expected in expected_values:
        exp_norm = normalize_answer(expected)
        if not exp_norm:
            continue
        if cand_norm == exp_norm:
            # Exact equality wins outright; no need to scan further answers.
            return {
                "match": True,
                "exact": True,
                "boxed": exp_norm in cand_boxed,
                "numeric": False,
                "reason": "exact",
            }
        if exp_norm in cand_norm or cand_norm in exp_norm:
            # Containment in either direction -- the weakest accepted signal.
            substring_hit = True
        expected_boxed = extract_boxed_values(expected)
        # Boxed comparisons also accept containment in either direction.
        for cand_box in cand_boxed:
            if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
                boxed_hit = True
        for exp_box in expected_boxed:
            if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
                boxed_hit = True
        exp_num = parse_numeric_value(exp_norm)
        if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
            numeric_hit = True
    if boxed_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": True,
            "numeric": numeric_hit,
            "reason": "boxed",
        }
    if numeric_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": True,
            "reason": "numeric",
        }
    if substring_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "substring",
        }
    return {
        "match": False,
        "exact": False,
        "boxed": False,
        "numeric": False,
        "reason": "no_match",
    }
def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
    """Collect every acceptable answer string for *row*.

    Takes the final-answer field (if present) plus the target field, which
    may be a plain string, a JSON-encoded structure, a dict of values or
    lists, or a list. Non-empty stringified entries are returned in order.
    """
    collected: List[str] = []

    def push(value: Any) -> None:
        # Append the stripped string form of value, skipping empties.
        text = as_text(value)
        if text:
            collected.append(text)

    final_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
    target_field = as_text(data_cfg.get("target_field")) or "target"
    if row.get(final_field) is not None:
        push(row.get(final_field))
    target = row.get(target_field)
    if target is None:
        return collected
    if isinstance(target, str):
        stripped = target.strip()
        if not stripped:
            return collected
        try:
            target = json.loads(stripped)
        except json.JSONDecodeError:
            # Not JSON: treat the raw string itself as the answer.
            collected.append(stripped)
            return collected
    if isinstance(target, dict):
        for value in target.values():
            if isinstance(value, list):
                for item in value:
                    push(item)
            else:
                push(value)
    elif isinstance(target, list):
        for item in target:
            push(item)
    else:
        push(target)
    return collected
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Render the user message: the prompt plus optional metadata lines."""
    field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(field)) or "Solve the math task."
    labels = (
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    )
    meta = [f"{label}: {as_text(row.get(key))}" for key, label in labels if as_text(row.get(key))]
    if not meta:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(meta)
def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    """Format the full prompt, preferring the tokenizer's chat template."""
    system_prompt = as_text(data_cfg.get("system_prompt")) or "You are a rigorous mathematical reasoning assistant."
    user_block = build_user_block(row, data_cfg)
    if not getattr(tokenizer, "chat_template", None):
        # Plain-text fallback when the tokenizer defines no chat template.
        return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_block},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
    """Drop the echoed prompt prefix from a decoded generation, then strip."""
    remainder = full_generation
    if remainder.startswith(prompt_text):
        remainder = remainder[len(prompt_text):]
    return remainder.strip()
def load_model_and_tokenizer(
    base_model: str,
    adapter_path: Optional[Path],
    trust_remote_code: bool,
) -> Tuple[Any, AutoTokenizer]:
    """Load the tokenizer and causal LM, optionally stacking a LoRA adapter.

    Ensures a pad token exists (eos -> unk -> newly added "<|pad|>") so
    batched generation has something to pad with. The model is loaded in
    bf16 with device_map="auto" when CUDA is available, otherwise in
    float32 on CPU, and returned in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    if tokenizer.pad_token is None:
        # No reusable special token existed; register a dedicated pad token.
        # NOTE(review): the model's embedding matrix is not resized here when
        # "<|pad|>" is newly added -- confirm generation never embeds it.
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=trust_remote_code,
    )
    if adapter_path is not None:
        # Attach LoRA weights on top of the base model.
        model = PeftModel.from_pretrained(model, str(adapter_path))
    model.eval()
    return model, tokenizer
def make_bucket() -> Dict[str, Any]:
    """Return a freshly zeroed hit-counter bucket for one metric slice."""
    counter_keys = (
        "evaluated_rows",
        "pass_at_1_hits",
        "pass_at_k_hits",
        "exact_at_1_hits",
        "exact_at_k_hits",
        "boxed_at_k_hits",
    )
    return {key: 0 for key in counter_keys}
def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
    """Increment *bucket* counters in place for one newly evaluated row."""
    bucket["evaluated_rows"] += 1
    flag_pairs = (
        ("pass_at_1_hits", hit1),
        ("pass_at_k_hits", hitk),
        ("exact_at_1_hits", exact1),
        ("exact_at_k_hits", exactk),
        ("boxed_at_k_hits", boxedk),
    )
    for key, flagged in flag_pairs:
        if flagged:
            bucket[key] += 1
def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
    """Turn raw hit counts into rates; an empty bucket yields 0.0 rates."""
    rows = int(bucket.get("evaluated_rows", 0))
    denominator = max(rows, 1)  # avoid division by zero for empty slices

    def rate(key: str) -> float:
        return float(bucket.get(key, 0)) / denominator

    return {
        "evaluated_rows": rows,
        "pass_at_1": rate("pass_at_1_hits"),
        "pass_at_k": rate("pass_at_k_hits"),
        "exact_at_1": rate("exact_at_1_hits"),
        "exact_at_k": rate("exact_at_k_hits"),
        "boxed_at_k": rate("boxed_at_k_hits"),
    }
def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
    """Pick the eval parquet: CLI arg, then config entries, then fallbacks.

    Config/fallback candidates are only used when they exist on disk; the
    CLI argument is returned unconditionally. The last-resort default path
    is returned even if missing (callers validate existence).
    """
    if arg_eval_file is not None:
        return arg_eval_file
    post_eval_cfg = cfg.get("post_eval", {})
    data_cfg = cfg.get("data", {})
    candidates = (
        as_text(post_eval_cfg.get("eval_file")),
        as_text(data_cfg.get("default_validation_file")),
        "data/releases/v1/test.parquet",
        "workspace/data/releases/v1/test.parquet",
    )
    for candidate in candidates:
        if candidate and Path(candidate).exists():
            return Path(candidate)
    return Path("data/releases/v1/test.parquet")
def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
    """Sample k generations per held-out row, score them, and emit a report.

    Scoring is delegated to match_candidate(); aggregate rates are tracked
    overall plus per-family and per-difficulty. The full report is written
    as JSON to args.output_json, a short summary is printed, and the report
    dict is returned.

    Raises ValueError for out-of-range CLI values or a missing base model,
    and FileNotFoundError for a missing adapter path or evaluation file.
    """
    # Validate CLI values up front, before any expensive model loading.
    if args.k < 1:
        raise ValueError("--k must be >= 1.")
    if args.max_samples < 1:
        raise ValueError("--max-samples must be >= 1.")
    if args.max_new_tokens < 1:
        raise ValueError("--max-new-tokens must be >= 1.")
    if args.max_input_length < 128:
        raise ValueError("--max-input-length must be >= 128.")
    if args.temperature <= 0:
        raise ValueError("--temperature must be > 0.")
    if not 0 < args.top_p <= 1:
        raise ValueError("--top-p must be in (0, 1].")
    cfg = load_config(args.config)
    data_cfg = cfg.get("data", {})
    model_cfg = cfg.get("model", {})
    set_seed(args.seed)
    # The CLI override wins over config.model.base_model.
    base_model = args.base_model or as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("Base model is required via --base-model or config.model.base_model.")
    if args.adapter_path is not None and not args.adapter_path.exists():
        raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
    eval_file = resolve_eval_file(args.eval_file, cfg)
    if not eval_file.exists():
        raise FileNotFoundError(f"Evaluation file not found: {eval_file}")
    model, tokenizer = load_model_and_tokenizer(
        base_model=base_model,
        adapter_path=args.adapter_path,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
    )
    ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
    if args.max_samples > 0 and args.max_samples < len(ds):
        ds = ds.select(range(args.max_samples))
    totals = make_bucket()
    family_buckets: Dict[str, Dict[str, Any]] = {}
    difficulty_buckets: Dict[str, Dict[str, Any]] = {}
    processed_rows = 0
    skipped_no_expected = 0
    samples: List[Dict[str, Any]] = []
    model_device = next(model.parameters()).device
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    for row in ds:
        expected_values = flatten_expected(row, data_cfg)
        if not expected_values:
            # No ground-truth answer is available; the row cannot be scored.
            skipped_no_expected += 1
            continue
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_input_length,
        )
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        with torch.no_grad():
            # One generate() call produces all k samples via num_return_sequences.
            output_ids = model.generate(
                **inputs,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                num_return_sequences=args.k,
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        candidates = [extract_candidate_text(text, prompt_text) for text in generations]
        details = [match_candidate(candidate, expected_values) for candidate in candidates]
        matches = [bool(item["match"]) for item in details]
        exacts = [bool(item["exact"]) for item in details]
        boxed = [bool(item["boxed"]) for item in details]
        # @1 metrics look only at the first sample; @k pass if any sample hits.
        hit1 = bool(matches and matches[0])
        hitk = bool(any(matches))
        exact1 = bool(exacts and exacts[0])
        exactk = bool(any(exacts))
        boxedk = bool(any(boxed))
        update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
        family = as_text(row.get("family")) or "__unknown__"
        if family not in family_buckets:
            family_buckets[family] = make_bucket()
        update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
        difficulty = as_text(row.get("difficulty")) or "__unknown__"
        if difficulty not in difficulty_buckets:
            difficulty_buckets[difficulty] = make_bucket()
        update_bucket(
            difficulty_buckets[difficulty],
            hit1=hit1,
            hitk=hitk,
            exact1=exact1,
            exactk=exactk,
            boxedk=boxedk,
        )
        processed_rows += 1
        if args.progress_every > 0 and processed_rows % args.progress_every == 0:
            print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
        if len(samples) < args.sample_records:
            # Keep a capped number of transcripts for qualitative review.
            samples.append(
                {
                    "uid": as_text(row.get("uid")),
                    "family": family,
                    "difficulty": difficulty,
                    "prompt": as_text(row.get(prompt_field)),
                    "expected_values": expected_values[:5],
                    "candidates": candidates,
                    "match_details": details,
                    "matches": matches,
                }
            )
    total_eval = int(totals.get("evaluated_rows", 0))
    denominator = max(total_eval, 1)  # guard against division by zero
    pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
    pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
    exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
    exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
    boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
    # NOTE(review): fixed weights; presumably consumed by the quality-gated
    # promotion step mentioned in the file header -- confirm downstream use.
    composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k
    report: Dict[str, Any] = {
        "base_model": base_model,
        "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
        "eval_file": str(eval_file),
        "config": str(args.config),
        "evaluated_rows": total_eval,
        "skipped_rows_without_targets": skipped_no_expected,
        "requested_rows": len(ds),  # length AFTER --max-samples truncation
        "k": args.k,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_1": exact_at_1,
        "exact_at_k": exact_at_k,
        "boxed_at_k": boxed_at_k,
        "composite_score": composite_score,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
        "max_input_length": args.max_input_length,
        "seed": args.seed,
        "family_metrics": {
            key: finalize_bucket(family_buckets[key])
            for key in sorted(family_buckets.keys())
        },
        "difficulty_metrics": {
            key: finalize_bucket(difficulty_buckets[key])
            for key in sorted(difficulty_buckets.keys())
        },
        "samples": samples,
    }
    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
    summary_view = {
        "evaluated_rows": total_eval,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_k": exact_at_k,
        "composite_score": composite_score,
        "k": args.k,
    }
    print(json.dumps(summary_view, indent=2))
    print(f"Saved report to {args.output_json}")
    return report
def main() -> None:
    """Script entry point: parse CLI options and run the evaluation."""
    run_evaluation(parse_args())


if __name__ == "__main__":
    main()