#!/usr/bin/env python3
"""Self-consistency evaluation for math-conjecture model checkpoints."""
from __future__ import annotations
import argparse
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
import yaml
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"
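# BOXED_RE captures the contents of non-nested \boxed{...} spans; SPACE_RE collapses whitespace runs.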
BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
SPACE_RE = re.compile(r"\s+")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
parser.add_argument(
"--config",
type=Path,
default=DEFAULT_CONFIG_PATH,
help="Training config used for prompt formatting defaults.",
)
parser.add_argument(
"--base-model",
type=str,
default=None,
help="Override base model id from config.",
)
parser.add_argument(
"--adapter-path",
type=Path,
default=None,
help="Optional LoRA adapter path to load on top of base model.",
)
parser.add_argument(
"--eval-file",
type=Path,
default=None,
help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).",
)
parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
parser.add_argument("--seed", type=int, default=17, help="Random seed.")
parser.add_argument(
"--progress-every",
type=int,
default=25,
help="Print progress every N evaluated rows (0 disables).",
)
parser.add_argument(
"--sample-records",
type=int,
default=30,
help="How many sample records to store in report.",
)
parser.add_argument(
"--output-json",
type=Path,
default=DEFAULT_OUTPUT_JSON,
help="Where to write evaluation report.",
)
return parser.parse_args()
def as_text(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value.strip()
return str(value).strip()
def as_float(value: Any, default: float) -> float:
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def as_int(value: Any, default: int) -> int:
if value is None:
return default
try:
return int(value)
except (TypeError, ValueError):
return default
def load_config(path: Path) -> Dict[str, Any]:
cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
if not isinstance(cfg, dict):
raise ValueError("Invalid YAML config.")
return cfg
def normalize_answer(text: str) -> str:
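    r"""Lowercase the text, drop $ signs, \left/\right delimiters and LaTeX spacing commands, and collapse whitespace."""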
text = text.strip().lower()
text = text.replace("$", "")
text = text.replace("\\left", "").replace("\\right", "")
text = text.replace("\\,", "").replace("\\!", "").replace("\\;", "")
text = SPACE_RE.sub(" ", text)
return text.strip(" .")
def extract_boxed_values(text: str) -> List[str]:
return [normalize_answer(match) for match in BOXED_RE.findall(text or "") if normalize_answer(match)]
def parse_numeric_value(text: str) -> Optional[float]:
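    """Parse a normalized answer as a float, supporting plain numbers, scientific notation, and simple fractions."""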
normalized = normalize_answer(text)
if not normalized:
return None
candidate = normalized.replace(",", "")
if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
left, right = candidate.split("/", maxsplit=1)
try:
numerator = float(left.strip())
denominator = float(right.strip())
except ValueError:
return None
if denominator == 0:
return None
return numerator / denominator
if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
try:
return float(candidate)
except ValueError:
return None
return None
def approximately_equal(left: float, right: float) -> bool:
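    """Compare floats with a relative tolerance of 1e-6 (effectively absolute for values near zero)."""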
tolerance = 1e-6 * max(1.0, abs(left), abs(right))
return abs(left - right) <= tolerance
def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
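    r"""Score one candidate against the expected answers.

    Match tiers, strongest first: exact normalized equality, \boxed{} agreement,
    numeric equality within tolerance, then substring containment in either direction.
    """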
cand_norm = normalize_answer(candidate)
if not cand_norm:
return {
"match": False,
"exact": False,
"boxed": False,
"numeric": False,
"reason": "empty_candidate",
}
cand_boxed = extract_boxed_values(candidate)
cand_num = parse_numeric_value(cand_norm)
substring_hit = False
boxed_hit = False
numeric_hit = False
for expected in expected_values:
exp_norm = normalize_answer(expected)
if not exp_norm:
continue
if cand_norm == exp_norm:
return {
"match": True,
"exact": True,
"boxed": exp_norm in cand_boxed,
"numeric": False,
"reason": "exact",
}
if exp_norm in cand_norm or cand_norm in exp_norm:
substring_hit = True
expected_boxed = extract_boxed_values(expected)
for cand_box in cand_boxed:
if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
boxed_hit = True
for exp_box in expected_boxed:
if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
boxed_hit = True
exp_num = parse_numeric_value(exp_norm)
if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
numeric_hit = True
if boxed_hit:
return {
"match": True,
"exact": False,
"boxed": True,
"numeric": numeric_hit,
"reason": "boxed",
}
if numeric_hit:
return {
"match": True,
"exact": False,
"boxed": False,
"numeric": True,
"reason": "numeric",
}
if substring_hit:
return {
"match": True,
"exact": False,
"boxed": False,
"numeric": False,
"reason": "substring",
}
return {
"match": False,
"exact": False,
"boxed": False,
"numeric": False,
"reason": "no_match",
}
def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
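    """Collect reference answers from the final-answer field plus the (possibly JSON-encoded) target field, flattening dicts and lists."""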
out: List[str] = []
final_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
target_field = as_text(data_cfg.get("target_field")) or "target"
final_answer = row.get(final_field)
if final_answer is not None:
txt = as_text(final_answer)
if txt:
out.append(txt)
target = row.get(target_field)
if target is None:
return out
if isinstance(target, str):
stripped = target.strip()
if not stripped:
return out
try:
target = json.loads(stripped)
except json.JSONDecodeError:
out.append(stripped)
return out
if isinstance(target, dict):
for value in target.values():
if isinstance(value, list):
for item in value:
txt = as_text(item)
if txt:
out.append(txt)
else:
txt = as_text(value)
if txt:
out.append(txt)
elif isinstance(target, list):
for item in target:
txt = as_text(item)
if txt:
out.append(txt)
else:
txt = as_text(target)
if txt:
out.append(txt)
return out
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
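    """Return the prompt text, appending a Metadata block for any task_type/family/difficulty/source/status fields present."""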
prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
prompt = as_text(row.get(prompt_field))
if not prompt:
prompt = "Solve the math task."
meta_fields = [
("task_type", "Task type"),
("family", "Family"),
("difficulty", "Difficulty"),
("source_dataset", "Source"),
("status_as_of", "Status as of"),
]
lines = []
for key, label in meta_fields:
value = as_text(row.get(key))
if value:
lines.append(f"{label}: {value}")
if lines:
return f"{prompt}\n\nMetadata:\n" + "\n".join(lines)
return prompt
def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
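    """Render the system prompt and user block with the tokenizer's chat template when available, else a plain System/User/Assistant layout."""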
system_prompt = as_text(data_cfg.get("system_prompt"))
if not system_prompt:
system_prompt = "You are a rigorous mathematical reasoning assistant."
user_block = build_user_block(row, data_cfg)
if getattr(tokenizer, "chat_template", None):
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_block},
]
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
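    """Strip the echoed prompt prefix, if present, so only the model's continuation is scored."""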
if full_generation.startswith(prompt_text):
return full_generation[len(prompt_text) :].strip()
return full_generation.strip()
def load_model_and_tokenizer(
base_model: str,
adapter_path: Optional[Path],
trust_remote_code: bool,
) -> Tuple[Any, AutoTokenizer]:
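    """Load the tokenizer (ensuring a pad token) and the base model in eval mode, optionally applying a LoRA adapter."""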
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
model = AutoModelForCausalLM.from_pretrained(
base_model,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=trust_remote_code,
)
if adapter_path is not None:
model = PeftModel.from_pretrained(model, str(adapter_path))
model.eval()
return model, tokenizer
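# Per-slice metric accumulators: make_bucket() holds raw hit counters, update_bucket()
# increments them for one evaluated row, and finalize_bucket() turns counts into rates.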
def make_bucket() -> Dict[str, Any]:
return {
"evaluated_rows": 0,
"pass_at_1_hits": 0,
"pass_at_k_hits": 0,
"exact_at_1_hits": 0,
"exact_at_k_hits": 0,
"boxed_at_k_hits": 0,
}
def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
bucket["evaluated_rows"] += 1
if hit1:
bucket["pass_at_1_hits"] += 1
if hitk:
bucket["pass_at_k_hits"] += 1
if exact1:
bucket["exact_at_1_hits"] += 1
if exactk:
bucket["exact_at_k_hits"] += 1
if boxedk:
bucket["boxed_at_k_hits"] += 1
def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
total = max(int(bucket.get("evaluated_rows", 0)), 1)
rows = int(bucket.get("evaluated_rows", 0))
return {
"evaluated_rows": rows,
"pass_at_1": float(bucket.get("pass_at_1_hits", 0)) / total,
"pass_at_k": float(bucket.get("pass_at_k_hits", 0)) / total,
"exact_at_1": float(bucket.get("exact_at_1_hits", 0)) / total,
"exact_at_k": float(bucket.get("exact_at_k_hits", 0)) / total,
"boxed_at_k": float(bucket.get("boxed_at_k_hits", 0)) / total,
}
def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
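    """Pick the evaluation parquet: the CLI override if given, else the first existing candidate from post_eval.eval_file, data.default_validation_file, or the known release paths."""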
if arg_eval_file is not None:
return arg_eval_file
post_eval_cfg = cfg.get("post_eval", {})
data_cfg = cfg.get("data", {})
for candidate in (
as_text(post_eval_cfg.get("eval_file")),
as_text(data_cfg.get("default_validation_file")),
"data/releases/v1/test.parquet",
"workspace/data/releases/v1/test.parquet",
):
if not candidate:
continue
path = Path(candidate)
if path.exists():
return path
return Path("data/releases/v1/test.parquet")
def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
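    """Validate arguments, load the model and eval split, sample k generations per row, score them, and write the JSON report."""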
if args.k < 1:
raise ValueError("--k must be >= 1.")
if args.max_samples < 1:
raise ValueError("--max-samples must be >= 1.")
if args.max_new_tokens < 1:
raise ValueError("--max-new-tokens must be >= 1.")
if args.max_input_length < 128:
raise ValueError("--max-input-length must be >= 128.")
if args.temperature <= 0:
raise ValueError("--temperature must be > 0.")
if not 0 < args.top_p <= 1:
raise ValueError("--top-p must be in (0, 1].")
cfg = load_config(args.config)
data_cfg = cfg.get("data", {})
model_cfg = cfg.get("model", {})
set_seed(args.seed)
base_model = args.base_model or as_text(model_cfg.get("base_model"))
if not base_model:
raise ValueError("Base model is required via --base-model or config.model.base_model.")
if args.adapter_path is not None and not args.adapter_path.exists():
raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
eval_file = resolve_eval_file(args.eval_file, cfg)
if not eval_file.exists():
raise FileNotFoundError(f"Evaluation file not found: {eval_file}")
model, tokenizer = load_model_and_tokenizer(
base_model=base_model,
adapter_path=args.adapter_path,
trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
)
ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
if args.max_samples > 0 and args.max_samples < len(ds):
ds = ds.select(range(args.max_samples))
totals = make_bucket()
family_buckets: Dict[str, Dict[str, Any]] = {}
difficulty_buckets: Dict[str, Dict[str, Any]] = {}
processed_rows = 0
skipped_no_expected = 0
samples: List[Dict[str, Any]] = []
model_device = next(model.parameters()).device
prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
for row in ds:
expected_values = flatten_expected(row, data_cfg)
if not expected_values:
skipped_no_expected += 1
continue
prompt_text = build_prompt_text(row, tokenizer, data_cfg)
inputs = tokenizer(
prompt_text,
return_tensors="pt",
truncation=True,
max_length=args.max_input_length,
)
inputs = {k: v.to(model_device) for k, v in inputs.items()}
with torch.no_grad():
output_ids = model.generate(
**inputs,
do_sample=True,
temperature=args.temperature,
top_p=args.top_p,
num_return_sequences=args.k,
max_new_tokens=args.max_new_tokens,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
        # Decode only the newly generated tokens so candidates never contain the prompt itself;
        # extract_candidate_text stays as a safeguard for decodings that still echo the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        generations = tokenizer.batch_decode(output_ids[:, prompt_length:], skip_special_tokens=True)
        candidates = [extract_candidate_text(text, prompt_text) for text in generations]
details = [match_candidate(candidate, expected_values) for candidate in candidates]
matches = [bool(item["match"]) for item in details]
exacts = [bool(item["exact"]) for item in details]
boxed = [bool(item["boxed"]) for item in details]
hit1 = bool(matches and matches[0])
hitk = bool(any(matches))
exact1 = bool(exacts and exacts[0])
exactk = bool(any(exacts))
boxedk = bool(any(boxed))
update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
family = as_text(row.get("family")) or "__unknown__"
if family not in family_buckets:
family_buckets[family] = make_bucket()
update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
difficulty = as_text(row.get("difficulty")) or "__unknown__"
if difficulty not in difficulty_buckets:
difficulty_buckets[difficulty] = make_bucket()
update_bucket(
difficulty_buckets[difficulty],
hit1=hit1,
hitk=hitk,
exact1=exact1,
exactk=exactk,
boxedk=boxedk,
)
processed_rows += 1
if args.progress_every > 0 and processed_rows % args.progress_every == 0:
print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
if len(samples) < args.sample_records:
samples.append(
{
"uid": as_text(row.get("uid")),
"family": family,
"difficulty": difficulty,
"prompt": as_text(row.get(prompt_field)),
"expected_values": expected_values[:5],
"candidates": candidates,
"match_details": details,
"matches": matches,
}
)
total_eval = int(totals.get("evaluated_rows", 0))
denominator = max(total_eval, 1)
pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
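    # Composite score favors self-consistency: 30% pass@1, 50% pass@k, 20% exact@k.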
composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k
report: Dict[str, Any] = {
"base_model": base_model,
"adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
"eval_file": str(eval_file),
"config": str(args.config),
"evaluated_rows": total_eval,
"skipped_rows_without_targets": skipped_no_expected,
"requested_rows": len(ds),
"k": args.k,
"pass_at_1": pass_at_1,
"pass_at_k": pass_at_k,
"exact_at_1": exact_at_1,
"exact_at_k": exact_at_k,
"boxed_at_k": boxed_at_k,
"composite_score": composite_score,
"temperature": args.temperature,
"top_p": args.top_p,
"max_new_tokens": args.max_new_tokens,
"max_input_length": args.max_input_length,
"seed": args.seed,
"family_metrics": {
key: finalize_bucket(family_buckets[key])
for key in sorted(family_buckets.keys())
},
"difficulty_metrics": {
key: finalize_bucket(difficulty_buckets[key])
for key in sorted(difficulty_buckets.keys())
},
"samples": samples,
}
args.output_json.parent.mkdir(parents=True, exist_ok=True)
args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
summary_view = {
"evaluated_rows": total_eval,
"pass_at_1": pass_at_1,
"pass_at_k": pass_at_k,
"exact_at_k": exact_at_k,
"composite_score": composite_score,
"k": args.k,
}
print(json.dumps(summary_view, indent=2))
print(f"Saved report to {args.output_json}")
return report
def main() -> None:
args = parse_args()
run_evaluation(args)
if __name__ == "__main__":
main()
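# Example invocation (model id and adapter path are illustrative, not fixed by this script):
#   python scripts/eval_sota.py \
#       --base-model deepseek-ai/deepseek-math-7b-instruct \
#       --adapter-path runs/checkpoint-final \
#       --eval-file data/releases/v1/test.parquet \
#       --k 4 --max-samples 300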