| from __future__ import annotations |
|
|
| import json |
| import os |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| import yaml |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| import hackable |
| from hackable import reward_plugins as reward_plugins_mod |
| from hackable.utils import resolve_repo_path |
|
|
|
|
# System prompt sent with every MATH question (see _build_chat_prompts / main).
# It asks for step-by-step reasoning inside <think>...</think> followed by a single
# \boxed{...} final answer. main() scores completions by extracting the predicted
# answer (e.g. the last \boxed{...}), so keep this wording in sync with that extraction.
THINKING_SYSTEM_PROMPT = (
    "Solve the following math problem.\n"
    "Think step-by-step inside <think>...</think> tags.\n"
    "Then output only the final answer in LaTeX boxed format.\n"
    "Do not include any words or explanations outside the tags/boxed answer.\n"
    "Output format must be exactly:\n"
    "<think>your reasoning</think>\n"
    "\\boxed{your_final_answer}\n"
)
|
|
|
|
def _load_yaml(path: str) -> dict:
    """Load a YAML config file and return its top-level mapping.

    Args:
        path: Filesystem path of the YAML file (read as UTF-8).

    Returns:
        The parsed mapping. An empty document yields ``{}`` instead of ``None``
        so callers can chain ``.get(...)`` on the result safely.

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    # yaml.safe_load returns None for an empty document; every caller treats the
    # result as a dict, so normalise that case instead of failing later on `.get`.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise TypeError(
            f"Expected a mapping at the top level of '{path}', got {type(data).__name__}."
        )
    return data
|
|
|
|
def _dist_info() -> tuple[int, int, int]:
    """Return ``(rank, world_size, local_rank)`` read from the launcher's environment.

    Any variable that is unset falls back to the single-process layout ``(0, 1, 0)``.
    Non-integer values raise ``ValueError`` from ``int()``.
    """
    env = os.environ
    lookups = (("RANK", "0"), ("WORLD_SIZE", "1"), ("LOCAL_RANK", "0"))
    rank, world_size, local_rank = (int(env.get(name, fallback)) for name, fallback in lookups)
    return rank, world_size, local_rank
|
|
|
|
def _init_distributed() -> tuple[int, int, int]:
    """Create the default process group when running multi-process, then report the layout.

    A group is only initialised when ``WORLD_SIZE > 1`` and none exists yet. It uses
    the ``env://`` rendezvous, with NCCL when CUDA is available and Gloo otherwise.

    Returns:
        The ``(rank, world_size, local_rank)`` triple from ``_dist_info``.
    """
    rank, world_size, local_rank = _dist_info()
    needs_group = world_size > 1 and not dist.is_initialized()
    if needs_group:
        chosen_backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=chosen_backend, init_method="env://")
    return rank, world_size, local_rank
|
|
|
|
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
    """Find the on-disk checkpoint directory referred to by ``model_dir``.

    Candidates are tried in order; the first one that exists wins:

    1. ``model_dir`` as given: an absolute path, or a path relative to the
       current working directory (returned fully resolved).
    2. ``resolve_repo_path(model_dir)``.
    3. ``model_dir`` under the configured cache root
       (``storage.cache_dir`` in ``base_cfg``, default ``"cache"``).

    Args:
        base_cfg: Parsed base config; only ``storage.cache_dir`` is read.
        model_dir: Path (or path fragment) of the checkpoint to evaluate.

    Returns:
        The first existing candidate directory.

    Raises:
        FileNotFoundError: If none of the candidates exists.
    """
    candidate = Path(model_dir)
    # Covers both absolute and CWD-relative inputs (the two cases were handled identically).
    if candidate.exists():
        return candidate.resolve()

    repo_local = resolve_repo_path(model_dir)
    if repo_local.exists():
        return repo_local

    # An empty `storage:` section parses as None; treat it like a missing section.
    storage_cfg = base_cfg.get("storage") or {}
    cache_root = resolve_repo_path(storage_cfg.get("cache_dir", "cache"))
    prefixed = (cache_root / candidate).resolve()
    if prefixed.exists():
        return prefixed

    raise FileNotFoundError(
        f"Model directory not found locally: '{model_dir}'. "
        f"Tried '{candidate}', '{repo_local}', and '{prefixed}'."
    )
|
|
|
|
def _build_chat_prompts(
    tokenizer: AutoTokenizer, questions: list[str], system_prompt: str
) -> list[str]:
    """Render each question into a chat-formatted prompt string ready for generation.

    Each prompt is a two-turn conversation (``system_prompt``, then the stripped
    question). It is rendered with the tokenizer's chat template as text
    (``tokenize=False``), with the generation prompt appended.

    Raises:
        RuntimeError: If the tokenizer has no ``chat_template``.
    """
    if getattr(tokenizer, "chat_template", None) is None:
        raise RuntimeError("Tokenizer has no chat_template; cannot apply chat formatting.")

    def _render(question: str) -> str:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question.strip()},
        ]
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

    return [_render(question) for question in questions]
|
|
|
|
def _load_math_level_rows(
    level: str,
    split: str,
    max_samples: int | None,
    cache_dir: str | None,
    dataset_name: str = "EleutherAI/hendrycks_math",
    dataset_configs: tuple[str, ...] = (
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    ),
) -> tuple[list[str], list[str]]:
    """Collect (problem, solution) text pairs at a given difficulty level.

    Every config in ``dataset_configs`` of ``dataset_name`` is loaded in order
    with ``datasets.load_dataset``. Rows whose stripped ``level`` field equals
    ``level`` are kept, until ``max_samples`` pairs have been collected.

    Args:
        level: Exact level label to keep (e.g. ``"Level 1"``).
        split: Dataset split to load for every config.
        max_samples: Cap on returned pairs; ``None`` means no cap. A value
            ``<= 0`` returns empty lists without loading anything.
        cache_dir: ``datasets`` cache directory (``None`` for the library default).
        dataset_name: Hub dataset id to load.
        dataset_configs: Config names to scan, in order.

    Returns:
        ``(questions, references)``: parallel lists holding each row's
        ``problem`` and full ``solution`` text.
    """
    questions: list[str] = []
    references: list[str] = []

    # A non-positive cap can only ever produce empty results; skip the dataset I/O.
    if max_samples is not None and max_samples <= 0:
        return questions, references

    for config_name in dataset_configs:
        rows = load_dataset(
            dataset_name,
            config_name,
            split=split,
            cache_dir=cache_dir,
        )
        for row in rows:
            if str(row.get("level", "")).strip() != level:
                continue
            questions.append(str(row.get("problem", "")))
            references.append(str(row.get("solution", "")))
            # Pairs are appended one at a time, so hitting the cap means exactly max_samples.
            if max_samples is not None and len(questions) >= max_samples:
                return questions, references

    return questions, references
|
|
|
|
def _score_strict(completion: str, reference: str) -> float:
    """Strict score: 1.0 if the extracted predicted answer matches the reference answer.

    The reference-answer text is extracted from ``reference`` by
    ``_extract_reference_answer_text``. It counts as a match when both sides are
    equal after ``_normalize_answer_text``, or when both parse via
    ``_parse_numeric`` and ``_is_close`` accepts them. A missing predicted or
    reference answer scores 0.0.
    """
    pred_text = reward_plugins_mod._extract_predicted_answer_text(completion)
    ref_text = reward_plugins_mod._extract_reference_answer_text(reference)
    if not pred_text or not ref_text:
        return 0.0
    pred_norm = reward_plugins_mod._normalize_answer_text(pred_text)
    ref_norm = reward_plugins_mod._normalize_answer_text(ref_text)
    if pred_norm and ref_norm and pred_norm == ref_norm:
        return 1.0
    pred_value = reward_plugins_mod._parse_numeric(pred_text)
    ref_value = reward_plugins_mod._parse_numeric(ref_text)
    if (
        pred_value is not None
        and ref_value is not None
        and reward_plugins_mod._is_close(pred_value, ref_value)
    ):
        return 1.0
    return 0.0


def _score_lenient(completion: str, reference: str) -> float:
    """Lenient numeric score: 1.0 if the predicted number is close to the reference target.

    The prediction comes from the last ``\\boxed{...}`` when present. That box is
    parsed as a number, falling back to the last number found inside it.
    Without a box, the prediction is the last number anywhere in the completion.
    """
    ref_val = reward_plugins_mod._extract_reference_target(reference)
    boxed = reward_plugins_mod._extract_last_boxed(completion)
    if boxed:
        pred_val = reward_plugins_mod._parse_numeric(boxed)
        if pred_val is None:
            nums = reward_plugins_mod._extract_numbers(boxed)
            pred_val = nums[-1] if nums else None
    else:
        nums = reward_plugins_mod._extract_numbers(completion)
        pred_val = nums[-1] if nums else None

    if ref_val is not None and pred_val is not None and reward_plugins_mod._is_close(pred_val, ref_val):
        return 1.0
    return 0.0


def _generate_completions(
    model,
    tokenizer,
    prompts: list[str],
    device: torch.device,
    *,
    batch_size: int,
    max_prompt_length: int,
    max_new_tokens: int,
) -> list[str]:
    """Greedy-decode ``prompts`` in batches and return only the newly generated text.

    Relies on the tokenizer being configured for left padding, so every row's
    prompt ends at the same column. Gradient tracking is disabled by the
    ``@torch.no_grad()`` on ``main``, which is the only caller.
    """
    completions: list[str] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start : start + batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_length,
            # The chat template already rendered any BOS/special tokens into the text;
            # letting the tokenizer add them again would duplicate them.
            add_special_tokens=False,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        # With left padding, generated tokens for every row start at this offset.
        prompt_seq_len = input_ids.shape[1]

        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        for row in out:
            completions.append(tokenizer.decode(row[prompt_seq_len:], skip_special_tokens=True))
    return completions


def _gather_records(local_records: list[dict]) -> list[dict]:
    """Concatenate every rank's records (in rank order); identity when not distributed."""
    if not dist.is_initialized():
        return local_records
    # all_gather_object requires the output list length to equal the group size.
    gathered: list[list[dict] | None] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_records)
    return [record for part in gathered if part for record in part]


def _write_results(records: list[dict]) -> None:
    """Sort ``records`` by ``sample_index``, write them as JSONL, and print accuracies.

    The destination is ``$OUTPUT_PATH`` (resolved via ``resolve_repo_path``); parent
    directories are created as needed.
    """
    records.sort(key=lambda r: r["sample_index"])
    output_path = resolve_repo_path(
        os.environ.get(
            "OUTPUT_PATH",
            "artifacts/eval/math_level1_thinking_zeroshot/answers.jsonl",
        )
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        for row in records:
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")

    count = len(records)
    acc = sum(r["correctness"] for r in records) / count if count else 0.0
    acc_strict = sum(r["correctness_strict_boxed"] for r in records) / count if count else 0.0
    print(f"Wrote {len(records)} rows to {output_path}")
    print(f"Accuracy (lenient numeric): {acc:.4f}")
    print(f"Accuracy (strict boxed): {acc_strict:.4f}")


@torch.no_grad()
def main() -> None:
    """Zero-shot MATH "Level 1" evaluation of a local checkpoint, optionally data-parallel.

    Environment variables:
        BASE_CONFIG (required): YAML config path, resolved via ``resolve_repo_path``.
        MODEL_DIR / MODEL_PATH (one required): checkpoint directory to evaluate.
        MAX_PROMPT_LENGTH / MAX_NEW_TOKENS: override the config's
            ``generation.max_prompt_length`` (default 512) and
            ``generation.max_completion_length`` (default 256).
        MATH_SPLIT: dataset split (default ``"test"``).
        MAX_SAMPLES / EVAL_MAX_SAMPLES: cap on problems; negative (default -1) means no cap.
        BATCH_SIZE: prompts per ``generate`` call (default 4, must be >= 1).
        OUTPUT_PATH: JSONL destination written by rank 0.

    Each rank greedily decodes a strided shard of the problems. The records are
    gathered, and rank 0 writes them and prints lenient and strict accuracy.

    Raises:
        KeyError: If BASE_CONFIG is not set.
        ValueError: If neither MODEL_DIR nor MODEL_PATH is set, or BATCH_SIZE < 1.
    """
    # Remember whether a process group pre-existed so we only tear down our own.
    group_preexisted = dist.is_initialized()
    rank, world_size, local_rank = _init_distributed()

    base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
    model_dir = os.environ.get("MODEL_DIR") or os.environ.get("MODEL_PATH")
    if not model_dir:
        raise ValueError("Set MODEL_DIR or MODEL_PATH for the checkpoint to evaluate.")
    resolved_model_dir = _resolve_local_model_dir(base_cfg, model_dir)

    # Sections present but empty in YAML parse as None; treat them as missing.
    generation = base_cfg.get("generation") or {}
    model_cfg = base_cfg.get("model") or {}
    trainer_cfg = base_cfg.get("trainer") or {}
    storage_cfg = base_cfg.get("storage") or {}

    max_prompt_length = int(generation.get("max_prompt_length", 512))
    max_new_tokens = int(generation.get("max_completion_length", 256))
    max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", max_prompt_length))
    max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", max_new_tokens))

    split = os.environ.get("MATH_SPLIT", "test")
    max_samples_raw = int(os.environ.get("MAX_SAMPLES", os.environ.get("EVAL_MAX_SAMPLES", "-1")))
    max_samples = None if max_samples_raw < 0 else max_samples_raw

    batch_size = int(os.environ.get("BATCH_SIZE", "4"))
    if batch_size < 1:
        # Fail fast: 0 would crash inside range(), negatives would surface later as a zip() mismatch.
        raise ValueError(f"BATCH_SIZE must be >= 1, got {batch_size}.")

    cache_root = resolve_repo_path(storage_cfg.get("cache_dir", "cache"))
    datasets_cache = str(cache_root / "datasets")
    models_cache = str(cache_root / "models")
    trust_remote_code = bool(model_cfg.get("trust_remote_code", False))

    tokenizer = AutoTokenizer.from_pretrained(
        str(resolved_model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=models_cache,
        local_files_only=True,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps each prompt flush against its generated continuation.
    tokenizer.padding_side = "left"
    # Right truncation would cut off the chat template's trailing generation prompt;
    # keep the tail of over-long prompts instead.
    # NOTE(review): confirm this matches how prompts were truncated during training.
    tokenizer.truncation_side = "left"

    dtype = torch.bfloat16 if bool(trainer_cfg.get("bf16", True)) else torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        str(resolved_model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=models_cache,
        torch_dtype=dtype,
        local_files_only=True,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    questions, references = _load_math_level_rows(
        level="Level 1",
        split=split,
        max_samples=max_samples,
        cache_dir=datasets_cache,
    )

    # Strided sharding: rank r evaluates samples r, r + world_size, r + 2 * world_size, ...
    indices = list(range(rank, len(questions), world_size))
    local_questions = [questions[i] for i in indices]
    local_refs = [references[i] for i in indices]

    chat_prompts = _build_chat_prompts(tokenizer, local_questions, THINKING_SYSTEM_PROMPT)
    completions = _generate_completions(
        model,
        tokenizer,
        chat_prompts,
        device,
        batch_size=batch_size,
        max_prompt_length=max_prompt_length,
        max_new_tokens=max_new_tokens,
    )

    local_records = [
        {
            "sample_index": int(idx),
            "question": question,
            "reference_answer": reference,
            "model_answer_raw": completion,
            "correctness": float(_score_lenient(completion, reference)),
            "correctness_strict_boxed": float(_score_strict(completion, reference)),
        }
        for idx, question, reference, completion in zip(
            indices, local_questions, local_refs, completions, strict=True
        )
    ]

    merged = _gather_records(local_records)
    if dist.is_initialized() and not group_preexisted:
        # No collectives remain after the gather; release the group we created on every rank.
        dist.destroy_process_group()

    if rank != 0:
        return
    _write_results(merged)
|
|
|
|
# Script entry point: run the evaluation only when executed directly (e.g. via a
# distributed launcher), not when this module is imported.
if __name__ == "__main__":
    main()
|
|