from __future__ import annotations import json import os from pathlib import Path import torch import torch.distributed as dist import yaml from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer import hackable # noqa: F401 from hackable import reward_plugins as reward_plugins_mod from hackable.utils import resolve_repo_path THINKING_SYSTEM_PROMPT = ( "Solve the following math problem.\n" "Think step-by-step inside ... tags.\n" "Then output only the final answer in LaTeX boxed format.\n" "Do not include any words or explanations outside the tags/boxed answer.\n" "Output format must be exactly:\n" "your reasoning\n" "\\boxed{your_final_answer}\n" ) def _load_yaml(path: str) -> dict: with open(path, "r", encoding="utf-8") as handle: return yaml.safe_load(handle) def _dist_info() -> tuple[int, int, int]: rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) return rank, world_size, local_rank def _init_distributed() -> tuple[int, int, int]: rank, world_size, local_rank = _dist_info() if world_size > 1 and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dist.init_process_group(backend=backend, init_method="env://") return rank, world_size, local_rank def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path: candidate = Path(model_dir) if candidate.is_absolute() and candidate.exists(): return candidate.resolve() if not candidate.is_absolute() and candidate.exists(): return candidate.resolve() repo_local = resolve_repo_path(model_dir) if repo_local.exists(): return repo_local cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache")) prefixed = (cache_root / candidate).resolve() if prefixed.exists(): return prefixed raise FileNotFoundError( f"Model directory not found locally: '{model_dir}'. " f"Tried '{candidate}', '{repo_local}', and '{prefixed}'." ) def _build_chat_prompts( tokenizer: AutoTokenizer, questions: list[str], system_prompt: str ) -> list[str]: if getattr(tokenizer, "chat_template", None) is None: raise RuntimeError("Tokenizer has no chat_template; cannot apply chat formatting.") prompts: list[str] = [] for q in questions: messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": q.strip()}, ] text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) prompts.append(text) return prompts def _load_math_level_rows( level: str, split: str, max_samples: int | None, cache_dir: str | None, ) -> tuple[list[str], list[str]]: dataset_name = "EleutherAI/hendrycks_math" dataset_configs = ( "algebra", "counting_and_probability", "geometry", "intermediate_algebra", "number_theory", "prealgebra", "precalculus", ) questions: list[str] = [] references: list[str] = [] for config_name in dataset_configs: rows = load_dataset( dataset_name, config_name, split=split, cache_dir=cache_dir, ) for row in rows: row_level = str(row.get("level", "")).strip() if row_level != level: continue questions.append(str(row.get("problem", ""))) references.append(str(row.get("solution", ""))) if max_samples is not None and len(questions) >= max_samples: return questions[:max_samples], references[:max_samples] return questions, references @torch.no_grad() def main() -> None: rank, world_size, local_rank = _init_distributed() base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"]))) model_dir = os.environ.get("MODEL_DIR") or os.environ.get("MODEL_PATH") if not model_dir: raise ValueError("Set MODEL_DIR or MODEL_PATH for the checkpoint to evaluate.") resolved_model_dir = _resolve_local_model_dir(base_cfg, model_dir) generation = base_cfg.get("generation", {}) max_prompt_length = int(generation.get("max_prompt_length", 512)) max_new_tokens = int(generation.get("max_completion_length", 256)) max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", str(max_prompt_length))) max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", str(max_new_tokens))) split = os.environ.get("MATH_SPLIT", "test") max_samples_env = os.environ.get("MAX_SAMPLES", os.environ.get("EVAL_MAX_SAMPLES", "-1")) max_samples = None if int(max_samples_env) < 0 else int(max_samples_env) batch_size = int(os.environ.get("BATCH_SIZE", "4")) cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache")) datasets_cache = str(cache_root / "datasets") models_cache = str(cache_root / "models") tokenizer = AutoTokenizer.from_pretrained( str(resolved_model_dir), trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)), cache_dir=models_cache, local_files_only=True, ) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token # Decoder-only safe. tokenizer.padding_side = "left" dtype = torch.bfloat16 if bool(base_cfg.get("trainer", {}).get("bf16", True)) else torch.float16 model = AutoModelForCausalLM.from_pretrained( str(resolved_model_dir), trust_remote_code=bool(base_cfg.get("model", {}).get("trust_remote_code", False)), cache_dir=models_cache, torch_dtype=dtype, local_files_only=True, ) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") else: device = torch.device("cpu") model.to(device) model.eval() questions, references = _load_math_level_rows( level="Level 1", split=split, max_samples=max_samples, cache_dir=datasets_cache, ) indices = list(range(rank, len(questions), world_size)) local_questions = [questions[i] for i in indices] local_refs = [references[i] for i in indices] chat_prompts = _build_chat_prompts(tokenizer, local_questions, THINKING_SYSTEM_PROMPT) completions: list[str] = [] for start in range(0, len(chat_prompts), batch_size): batch_prompts = chat_prompts[start : start + batch_size] enc = tokenizer( batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_prompt_length, ) input_ids = enc["input_ids"].to(device) attn = enc["attention_mask"].to(device) prompt_seq_len = input_ids.shape[1] out = model.generate( input_ids=input_ids, attention_mask=attn, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) for bi in range(out.size(0)): gen_ids = out[bi, prompt_seq_len:] completions.append(tokenizer.decode(gen_ids, skip_special_tokens=True)) # Strict boxed correctness (project metric) strict_scores = [] for completion, reference in zip(completions, local_refs, strict=True): pred_text = reward_plugins_mod._extract_predicted_answer_text(completion) ref_text = reward_plugins_mod._extract_reference_answer_text(reference) if not pred_text or not ref_text: strict_scores.append(0.0) continue pred_norm = reward_plugins_mod._normalize_answer_text(pred_text) ref_norm = reward_plugins_mod._normalize_answer_text(ref_text) if pred_norm and ref_norm and pred_norm == ref_norm: strict_scores.append(1.0) continue pred_value = reward_plugins_mod._parse_numeric(pred_text) ref_value = reward_plugins_mod._parse_numeric(ref_text) if pred_value is not None and ref_value is not None and reward_plugins_mod._is_close(pred_value, ref_value): strict_scores.append(1.0) else: strict_scores.append(0.0) # Lenient numeric correctness fallback lenient_scores: list[float] = [] for completion, reference in zip(completions, local_refs, strict=True): ref_val = reward_plugins_mod._extract_reference_target(reference) boxed = reward_plugins_mod._extract_last_boxed(completion) if boxed: pred_val = reward_plugins_mod._parse_numeric(boxed) if pred_val is None: nums = reward_plugins_mod._extract_numbers(boxed) pred_val = nums[-1] if nums else None else: nums = reward_plugins_mod._extract_numbers(completion) pred_val = nums[-1] if nums else None if ref_val is not None and pred_val is not None and reward_plugins_mod._is_close(pred_val, ref_val): lenient_scores.append(1.0) else: lenient_scores.append(0.0) local_records: list[dict] = [] for i, idx in enumerate(indices): local_records.append( { "sample_index": int(idx), "question": local_questions[i], "reference_answer": local_refs[i], "model_answer_raw": completions[i], "correctness": float(lenient_scores[i]), "correctness_strict_boxed": float(strict_scores[i]), } ) if dist.is_initialized(): gathered: list[list[dict] | None] = [None for _ in range(world_size)] dist.all_gather_object(gathered, local_records) merged: list[dict] = [] for part in gathered: if part: merged.extend(part) else: merged = local_records if rank != 0: return merged.sort(key=lambda r: r["sample_index"]) output_path = resolve_repo_path( os.environ.get( "OUTPUT_PATH", "artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl", ) ) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w", encoding="utf-8") as handle: for row in merged: handle.write(json.dumps(row, ensure_ascii=True) + "\n") acc = sum(r["correctness"] for r in merged) / len(merged) if merged else 0.0 acc_strict = ( sum(r["correctness_strict_boxed"] for r in merged) / len(merged) if merged else 0.0 ) print(f"Wrote {len(merged)} rows to {output_path}") print(f"Accuracy (lenient numeric): {acc:.4f}") print(f"Accuracy (strict boxed): {acc_strict:.4f}") if __name__ == "__main__": main()