| from __future__ import annotations |
|
|
| import csv |
| import json |
| import os |
| import re |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| import yaml |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| import hackable |
| from hackable.data_plugins import GSM8KProvider |
| from hackable.paths import resolve_storage_path, storage_layout |
| from hackable.reward_plugins import gsm8k_correctness_reward |
| from hackable.utils import resolve_repo_path |
|
|
# Matches the first <think>...</think> pair and captures what lies between
# the tags; DOTALL lets the captured text span several lines.
THINK_RE = re.compile(r"<think>(.*?)</think>", flags=re.DOTALL)
|
|
|
|
def _load_yaml(path: str) -> dict:
    """Read a YAML file and return its top-level mapping.

    An empty document is parsed by ``yaml.safe_load`` as ``None``; it is
    returned here as an empty dict so callers can chain ``.get`` lookups.

    Args:
        path: Location of the YAML file, read as UTF-8.

    Returns:
        The parsed top-level mapping (``{}`` for an empty document).

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as handle:
        loaded = yaml.safe_load(handle)
    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f"Expected a mapping at the top level of '{path}', "
            f"got {type(loaded).__name__}"
        )
    return loaded
|
|
|
|
def _cot_word_len(completion: str) -> int:
    """Count the whitespace-separated words inside the first <think> block.

    Returns 0 when ``completion`` has no complete ``<think>...</think>`` pair
    or when the block contains only whitespace.
    """
    found = THINK_RE.search(completion)
    if found is None:
        return 0
    # str.split() without arguments ignores surrounding whitespace and yields
    # an empty list for blank text, so no separate strip/emptiness check is
    # needed.
    return len(found.group(1).split())
|
|
|
|
def _model_dtype(cfg: dict):
    """Pick the torch dtype used to load the model.

    Reads ``cfg["trainer"]["bf16"]`` (treated as true when absent) and returns
    ``torch.bfloat16`` when it is truthy, ``torch.float16`` otherwise.
    """
    trainer_cfg = cfg.get("trainer", {})
    if bool(trainer_cfg.get("bf16", True)):
        return torch.bfloat16
    return torch.float16
|
|
|
|
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
    """Return the ``(datasets, models)`` cache locations for this config.

    The cache root is ``base_cfg["storage"]["cache_dir"]`` (``"cache"`` when
    unset); it is handed to ``storage_layout`` and the ``datasets`` and
    ``models`` attributes of the result are returned.
    """
    storage_cfg = base_cfg.get("storage", {})
    cache_dir = storage_cfg.get("cache_dir", "cache")
    cache_layout = storage_layout(cache_dir)
    return cache_layout.datasets, cache_layout.models
|
|
|
|
def _dist_info() -> tuple[int, int, int]:
    """Read ``(rank, world_size, local_rank)`` from the environment.

    Uses the ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` variables and falls
    back to the single-process values ``(0, 1, 0)`` when they are unset.
    """
    env = os.environ
    return (
        int(env.get("RANK", "0")),
        int(env.get("WORLD_SIZE", "1")),
        int(env.get("LOCAL_RANK", "0")),
    )
|
|
|
|
def _init_distributed() -> tuple[int, int, int]:
    """Initialise the default process group for multi-process runs.

    Nothing is initialised when ``WORLD_SIZE`` is 1 or a group already
    exists. NCCL is used when CUDA is available, gloo otherwise, with the
    ``env://`` rendezvous.

    Returns:
        ``(rank, world_size, local_rank)`` as read from the environment.
    """
    rank, world_size, local_rank = _dist_info()
    if world_size > 1 and not dist.is_initialized():
        if torch.cuda.is_available():
            # Bind this process to its GPU before the NCCL group exists, so
            # collectives issued before any model is loaded (e.g. an early
            # barrier) run on the right device instead of a guessed one.
            torch.cuda.set_device(local_rank)
            backend = "nccl"
        else:
            backend = "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
    return rank, world_size, local_rank
|
|
|
|
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
    """Locate ``model_dir`` on the local filesystem.

    Candidates are tried in order and the first one that exists is returned:

    1. ``model_dir`` as given (absolute, or relative to the working
       directory), resolved.
    2. ``resolve_repo_path(model_dir)``.
    3. ``model_dir`` joined under the configured cache directory
       (``base_cfg["storage"]["cache_dir"]``, default ``"cache"``).

    Raises:
        FileNotFoundError: If none of the candidates exists.
    """
    as_given = Path(model_dir)
    # The same check applies whether the path is absolute or relative.
    if as_given.exists():
        return as_given.resolve()

    in_repo = resolve_repo_path(model_dir)
    if in_repo.exists():
        return in_repo

    cache_dir = base_cfg.get("storage", {}).get("cache_dir", "cache")
    in_cache = (resolve_repo_path(cache_dir) / as_given).resolve()
    if in_cache.exists():
        return in_cache

    raise FileNotFoundError(
        f"Model directory not found locally: '{model_dir}'. "
        f"Tried '{as_given}', '{in_repo}', and '{in_cache}'."
    )
|
|
|
|
def _resolve_sweep_root(base_cfg: dict, requested_sweep_root: Path) -> Path:
    """Resolve the sweep root and check that it contains run directories.

    The requested path is resolved with ``resolve_storage_path`` against the
    configured cache directory (``base_cfg["storage"]["cache_dir"]``, default
    ``"cache"``). It is accepted only when it is a directory with at least
    one sub-directory whose name starts with ``run_``.

    Raises:
        FileNotFoundError: If the resolved path is not such a directory.
    """
    cache_dir = base_cfg.get("storage", {}).get("cache_dir", "cache")
    resolved = resolve_storage_path(requested_sweep_root, cache_dir)
    if resolved.is_dir():
        for entry in resolved.iterdir():
            if entry.is_dir() and entry.name.startswith("run_"):
                return resolved
    raise FileNotFoundError(
        "Could not resolve SWEEP_ROOT with run directories: "
        f"{resolved}"
    )
|
|
|
|
def _discover_model_dirs(sweep_root: Path) -> list[Path]:
    """List the ``run_*`` sub-directories of ``sweep_root`` in sorted order.

    Raises:
        FileNotFoundError: If ``sweep_root`` has no directory whose name
            starts with ``run_``.
    """
    run_dirs: list[Path] = []
    for entry in sweep_root.iterdir():
        if entry.is_dir() and entry.name.startswith("run_"):
            run_dirs.append(entry)
    if len(run_dirs) == 0:
        raise FileNotFoundError(
            f"No run directories starting with 'run_' found in {sweep_root}"
        )
    run_dirs.sort()
    return run_dirs
|
|
|
|
@torch.no_grad()
def evaluate_one_model(
    model_dir: Path,
    base_cfg: dict,
    eval_max_samples: int,
    batch_size: int,
) -> list[dict]:
    """Greedy-decode GSM8K test prompts with one checkpoint and score them.

    The loaded samples are split round-robin across ranks (rank ``r`` takes
    indices ``r, r + world_size, ...``). When a process group is initialised
    the per-rank records are all-gathered, so every rank returns the full
    list.

    Args:
        model_dir: Local directory holding the checkpoint to evaluate.
        base_cfg: Parsed base config; the ``generation``, ``model``,
            ``trainer`` and ``storage`` sections are read.
        eval_max_samples: Maximum number of test samples to load; a negative
            value is passed to the provider as ``None``.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.

    Returns:
        One record per evaluated sample, sorted by ``sample_index``, with the
        keys ``sample_index``, ``prompt``, ``reference``, ``completion``,
        ``correctness`` and ``cot_words``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    rank, world_size, local_rank = _dist_info()
    generation = base_cfg.get("generation", {})
    max_prompt_len = int(generation.get("max_prompt_length", 512))
    max_completion_len = int(generation.get("max_completion_length", 256))
    model_name_fallback = str(base_cfg["model"]["name"])
    trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", False))
    dtype = _model_dtype(base_cfg)
    datasets_cache, models_cache = _get_cache_paths(base_cfg)

    provider = GSM8KProvider()
    all_samples = provider.load(
        split="test",
        max_samples=None if eval_max_samples < 0 else eval_max_samples,
        cache_dir=str(datasets_cache),
    )
    # Round-robin shard: this rank evaluates every world_size-th sample.
    indices = list(range(rank, len(all_samples), world_size))
    local_samples = [all_samples[idx] for idx in indices]
    prompts = [sample.prompt for sample in local_samples]
    refs = [sample.target for sample in local_samples]
    metadata = [sample.metadata for sample in local_samples]

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_dir),
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )
    except Exception:
        # Deliberate best-effort fallback to the base model's tokenizer when
        # the one in the checkpoint directory cannot be loaded.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_fallback,
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=str(models_cache),
        torch_dtype=dtype,
        local_files_only=True,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue from the last position of each row, so
    # batched generation needs the padding on the left; with right padding
    # the model would be asked to continue after pad tokens.
    tokenizer.padding_side = "left"

    completions: list[str] = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start : start + batch_size]
        enc = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_len,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=max_completion_len,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # generate() returns the padded prompt followed by the new tokens, so
        # the completion of every row starts at the padded input width, not
        # at that row's unpadded prompt length.
        prompt_width = input_ids.shape[1]
        completions.extend(
            tokenizer.batch_decode(out[:, prompt_width:], skip_special_tokens=True)
        )

    scores = gsm8k_correctness_reward(
        prompts=prompts,
        completions=completions,
        references=refs,
        metadata=metadata,
    )

    local_records: list[dict] = [
        {
            "sample_index": int(sample_index),
            "prompt": prompt,
            "reference": reference,
            "completion": completion,
            "correctness": float(score),
            "cot_words": int(_cot_word_len(completion)),
        }
        for sample_index, prompt, reference, completion, score in zip(
            indices, prompts, refs, completions, scores, strict=True
        )
    ]

    # Drop the model before the gather so its memory can be reclaimed before
    # the next checkpoint is loaded.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if dist.is_initialized():
        gathered: list[list[dict] | None] = [None for _ in range(world_size)]
        dist.all_gather_object(gathered, local_records)
        merged: list[dict] = []
        for part in gathered:
            if part:
                merged.extend(part)
    else:
        merged = local_records

    # Restore dataset order after the round-robin sharding.
    merged.sort(key=lambda row: row["sample_index"])
    return merged
|
|
|
|
def _summarize(records: list[dict], model_dir: str) -> dict:
    """Aggregate per-sample records into one summary row for a model.

    Args:
        records: Rows carrying ``correctness`` and ``cot_words`` values.
        model_dir: Path of the evaluated model; its last component becomes
            the summary ``name``.

    Returns:
        A dict with ``name``, ``model_dir``, ``num_examples``, ``accuracy``
        (mean correctness) and ``avg_cot_words`` (mean ``cot_words``). Both
        means are 0.0 when ``records`` is empty.
    """
    count = len(records)
    mean_correct = 0.0
    mean_cot = 0.0
    if count:
        mean_correct = sum(float(row["correctness"]) for row in records) / count
        mean_cot = sum(float(row["cot_words"]) for row in records) / count
    return {
        "name": Path(model_dir).name,
        "model_dir": model_dir,
        "num_examples": count,
        "accuracy": float(mean_correct),
        "avg_cot_words": float(mean_cot),
    }
|
|
|
|
| def _write_accuracy_svg(summaries: list[dict], path: Path) -> None: |
| width = 1000 |
| height = 460 |
| left_margin = 70 |
| right_margin = 30 |
| top_margin = 70 |
| bottom_margin = 90 |
| plot_w = width - left_margin - right_margin |
| plot_h = height - top_margin - bottom_margin |
| y_base = top_margin + plot_h |
|
|
| runs = [row["name"] for row in summaries] |
| acc_vals = [float(row["accuracy"]) for row in summaries] |
| vmax = max(1.0, max(acc_vals) if acc_vals else 1.0) |
|
|
| bar_count = max(1, len(runs)) |
| slot_w = plot_w / bar_count |
| bar_w = min(120, max(30, int(slot_w * 0.55))) |
| palette = ["#2563eb", "#dc2626", "#16a34a", "#ca8a04", "#7c3aed", "#0891b2"] |
|
|
| parts: list[str] = [] |
| parts.append(f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}">') |
| parts.append('<rect width="100%" height="100%" fill="#ffffff"/>') |
| parts.append( |
| '<text x="40" y="34" font-size="20" font-family="sans-serif">Sweep Evaluation: GSM8K Accuracy</text>' |
| ) |
| parts.append( |
| f'<line x1="{left_margin}" y1="{y_base}" x2="{left_margin + plot_w}" y2="{y_base}" stroke="#111" stroke-width="2" />' |
| ) |
| parts.append( |
| f'<line x1="{left_margin}" y1="{top_margin}" x2="{left_margin}" y2="{y_base}" stroke="#111" stroke-width="2" />' |
| ) |
|
|
| |
| for tick in [0.0, 0.25, 0.5, 0.75, 1.0]: |
| y = y_base - int((tick / vmax) * plot_h) if vmax > 0 else y_base |
| parts.append( |
| f'<line x1="{left_margin - 6}" y1="{y}" x2="{left_margin}" y2="{y}" stroke="#111" stroke-width="1" />' |
| ) |
| parts.append( |
| f'<text x="{left_margin - 10}" y="{y + 4}" text-anchor="end" font-size="11" font-family="sans-serif">{tick:.2f}</text>' |
| ) |
|
|
| for idx, (run_name, acc) in enumerate(zip(runs, acc_vals, strict=True)): |
| center_x = left_margin + int((idx + 0.5) * slot_w) |
| bar_h = int((acc / vmax) * plot_h) if vmax > 0 else 0 |
| x = center_x - bar_w // 2 |
| y = y_base - bar_h |
| color = palette[idx % len(palette)] |
| parts.append(f'<rect x="{x}" y="{y}" width="{bar_w}" height="{bar_h}" fill="{color}" />') |
| parts.append( |
| f'<text x="{center_x}" y="{y - 8}" text-anchor="middle" font-size="12" font-family="sans-serif">{acc:.3f}</text>' |
| ) |
| parts.append( |
| f'<text x="{center_x}" y="{y_base + 18}" text-anchor="middle" font-size="11" font-family="sans-serif">{run_name}</text>' |
| ) |
|
|
| parts.append("</svg>") |
| path.write_text("\n".join(parts), encoding="utf-8") |
|
|
|
|
def main() -> None:
    """Evaluate every ``run_*`` checkpoint of a sweep and write the reports.

    Configuration comes from environment variables:

    * ``BASE_CONFIG`` (required): YAML config path, resolved with
      ``resolve_repo_path``.
    * ``SWEEP_ROOT`` (required): directory containing the ``run_*`` folders.
    * ``OUT_ROOT`` (optional): output directory; defaults to
      ``<sweep_root>/eval_results``.
    * ``EVAL_MAX_SAMPLES`` (optional, default 200) and ``EVAL_BATCH_SIZE``
      (optional, default 4).

    Rank 0 writes one JSONL file of per-sample records per model, plus a
    JSON summary, a CSV summary and an SVG accuracy chart. The process group,
    if one was initialised, is destroyed before returning.
    """
    rank, _, _ = _init_distributed()
    try:
        base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
        requested_sweep_root = Path(os.environ["SWEEP_ROOT"])
        sweep_root = _resolve_sweep_root(base_cfg, requested_sweep_root)
        if "OUT_ROOT" in os.environ:
            out_root = resolve_repo_path(os.environ["OUT_ROOT"])
        else:
            out_root = (sweep_root / "eval_results").resolve()
        eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
        eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))

        model_dirs = _discover_model_dirs(sweep_root)
        resolved_model_dirs = [
            _resolve_local_model_dir(base_cfg, str(path)) for path in model_dirs
        ]

        if rank == 0:
            out_root.mkdir(parents=True, exist_ok=True)
            (out_root / "outputs").mkdir(parents=True, exist_ok=True)

        # Let rank 0 finish creating the output directories before any rank
        # moves on.
        if dist.is_initialized():
            dist.barrier()

        summaries: list[dict] = []
        for model_dir in resolved_model_dirs:
            records = evaluate_one_model(
                model_dir=model_dir,
                base_cfg=base_cfg,
                eval_max_samples=eval_max_samples,
                batch_size=eval_batch_size,
            )
            if rank == 0:
                output_jsonl = out_root / "outputs" / f"{model_dir.name}_outputs.jsonl"
                with output_jsonl.open("w", encoding="utf-8") as handle:
                    for row in records:
                        handle.write(json.dumps(row, ensure_ascii=True) + "\n")
                summary = _summarize(records, str(model_dir))
                summary["outputs_jsonl"] = str(output_jsonl)
                summaries.append(summary)

        if dist.is_initialized():
            dist.barrier()

        # Only rank 0 holds the summaries and writes the reports.
        if rank != 0:
            return

        json_path = out_root / "sweep_eval_summary.json"
        csv_path = out_root / "sweep_eval_summary.csv"
        svg_path = out_root / "sweep_eval_accuracy.svg"
        json_path.write_text(json.dumps(summaries, indent=2), encoding="utf-8")
        with csv_path.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(
                handle,
                fieldnames=[
                    "name",
                    "model_dir",
                    "num_examples",
                    "accuracy",
                    "avg_cot_words",
                    "outputs_jsonl",
                ],
            )
            writer.writeheader()
            writer.writerows(summaries)
        _write_accuracy_svg(summaries, svg_path)

        print(f"Saved summary: {json_path}")
        print(f"Saved summary: {csv_path}")
        print(f"Saved plot: {svg_path}")
        print(f"Saved outputs dir: {out_root / 'outputs'}")
    finally:
        # Tear the process group down on every rank and every exit path,
        # including the early return of non-zero ranks above.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
|