from __future__ import annotations import csv import json import os import re from pathlib import Path import torch import torch.distributed as dist import yaml from transformers import AutoModelForCausalLM, AutoTokenizer import hackable # noqa: F401 from hackable.data_plugins import GSM8KProvider from hackable.paths import resolve_storage_path, storage_layout from hackable.reward_plugins import gsm8k_correctness_reward from hackable.utils import resolve_repo_path THINK_RE = re.compile(r"(.*?)", re.DOTALL) def _load_yaml(path: str) -> dict: with open(path, "r", encoding="utf-8") as handle: return yaml.safe_load(handle) def _cot_word_len(completion: str) -> int: match = THINK_RE.search(completion) text = match.group(1).strip() if match else "" return len(text.split()) if text else 0 def _model_dtype(cfg: dict): return torch.bfloat16 if bool(cfg.get("trainer", {}).get("bf16", True)) else torch.float16 def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]: layout = storage_layout(base_cfg.get("storage", {}).get("cache_dir", "cache")) return layout.datasets, layout.models def _dist_info() -> tuple[int, int, int]: rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) return rank, world_size, local_rank def _init_distributed() -> tuple[int, int, int]: rank, world_size, local_rank = _dist_info() if world_size > 1 and not dist.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" dist.init_process_group(backend=backend, init_method="env://") return rank, world_size, local_rank def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path: candidate = Path(model_dir) if candidate.is_absolute() and candidate.exists(): return candidate.resolve() if not candidate.is_absolute() and candidate.exists(): return candidate.resolve() repo_local = resolve_repo_path(model_dir) if repo_local.exists(): return repo_local cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache")) prefixed = (cache_root / candidate).resolve() if prefixed.exists(): return prefixed raise FileNotFoundError( f"Model directory not found locally: '{model_dir}'. " f"Tried '{candidate}', '{repo_local}', and '{prefixed}'." ) def _resolve_sweep_root(base_cfg: dict, requested_sweep_root: Path) -> Path: candidate = resolve_storage_path( requested_sweep_root, base_cfg.get("storage", {}).get("cache_dir", "cache"), ) if candidate.is_dir() and any(path.is_dir() and path.name.startswith("run_") for path in candidate.iterdir()): return candidate raise FileNotFoundError( "Could not resolve SWEEP_ROOT with run directories: " f"{candidate}" ) def _discover_model_dirs(sweep_root: Path) -> list[Path]: dirs = [ path for path in sweep_root.iterdir() if path.is_dir() and path.name.startswith("run_") ] if not dirs: raise FileNotFoundError( f"No run directories starting with 'run_' found in {sweep_root}" ) return sorted(dirs) @torch.no_grad() def evaluate_one_model( model_dir: Path, base_cfg: dict, eval_max_samples: int, batch_size: int, ) -> list[dict]: rank, world_size, local_rank = _dist_info() generation = base_cfg.get("generation", {}) max_prompt_len = int(generation.get("max_prompt_length", 512)) max_completion_len = int(generation.get("max_completion_length", 256)) model_name_fallback = str(base_cfg["model"]["name"]) trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", False)) dtype = _model_dtype(base_cfg) datasets_cache, models_cache = _get_cache_paths(base_cfg) provider = GSM8KProvider() all_samples = provider.load( split="test", max_samples=None if eval_max_samples < 0 else eval_max_samples, cache_dir=str(datasets_cache), ) indices = list(range(rank, len(all_samples), world_size)) local_samples = [all_samples[idx] for idx in indices] prompts = [sample.prompt for sample in local_samples] refs = [sample.target for sample in local_samples] metadata = [sample.metadata for sample in local_samples] try: tokenizer = AutoTokenizer.from_pretrained( str(model_dir), trust_remote_code=trust_remote_code, cache_dir=str(models_cache), local_files_only=True, ) except Exception: tokenizer = AutoTokenizer.from_pretrained( model_name_fallback, trust_remote_code=trust_remote_code, cache_dir=str(models_cache), local_files_only=True, ) model = AutoModelForCausalLM.from_pretrained( str(model_dir), trust_remote_code=trust_remote_code, cache_dir=str(models_cache), torch_dtype=dtype, local_files_only=True, ) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") else: device = torch.device("cpu") model.to(device) model.eval() if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token completions: list[str] = [] for start in range(0, len(prompts), batch_size): batch_prompts = prompts[start : start + batch_size] enc = tokenizer( batch_prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_prompt_len, ) input_ids = enc["input_ids"].to(device) attn = enc["attention_mask"].to(device) out = model.generate( input_ids=input_ids, attention_mask=attn, max_new_tokens=max_completion_len, do_sample=False, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, ) prompt_lens = attn.sum(dim=1).tolist() for idx in range(out.size(0)): completion_ids = out[idx, int(prompt_lens[idx]) :] completions.append(tokenizer.decode(completion_ids, skip_special_tokens=True)) scores = gsm8k_correctness_reward( prompts=prompts, completions=completions, references=refs, metadata=metadata, ) local_records: list[dict] = [] for i, (prompt, reference, completion, score) in enumerate( zip(prompts, refs, completions, scores, strict=True) ): local_records.append( { "sample_index": int(indices[i]), "prompt": prompt, "reference": reference, "completion": completion, "correctness": float(score), "cot_words": int(_cot_word_len(completion)), } ) del model if torch.cuda.is_available(): torch.cuda.empty_cache() if dist.is_initialized(): gathered: list[list[dict] | None] = [None for _ in range(world_size)] dist.all_gather_object(gathered, local_records) merged: list[dict] = [] for part in gathered: if part: merged.extend(part) else: merged = local_records merged.sort(key=lambda row: row["sample_index"]) return merged def _summarize(records: list[dict], model_dir: str) -> dict: if not records: return { "name": Path(model_dir).name, "model_dir": model_dir, "num_examples": 0, "accuracy": 0.0, "avg_cot_words": 0.0, } accuracy = sum(float(row["correctness"]) for row in records) / len(records) avg_cot = sum(float(row["cot_words"]) for row in records) / len(records) return { "name": Path(model_dir).name, "model_dir": model_dir, "num_examples": len(records), "accuracy": float(accuracy), "avg_cot_words": float(avg_cot), } def _write_accuracy_svg(summaries: list[dict], path: Path) -> None: width = 1000 height = 460 left_margin = 70 right_margin = 30 top_margin = 70 bottom_margin = 90 plot_w = width - left_margin - right_margin plot_h = height - top_margin - bottom_margin y_base = top_margin + plot_h runs = [row["name"] for row in summaries] acc_vals = [float(row["accuracy"]) for row in summaries] vmax = max(1.0, max(acc_vals) if acc_vals else 1.0) bar_count = max(1, len(runs)) slot_w = plot_w / bar_count bar_w = min(120, max(30, int(slot_w * 0.55))) palette = ["#2563eb", "#dc2626", "#16a34a", "#ca8a04", "#7c3aed", "#0891b2"] parts: list[str] = [] parts.append(f'') parts.append('') parts.append( 'Sweep Evaluation: GSM8K Accuracy' ) parts.append( f'' ) parts.append( f'' ) # y-axis ticks for tick in [0.0, 0.25, 0.5, 0.75, 1.0]: y = y_base - int((tick / vmax) * plot_h) if vmax > 0 else y_base parts.append( f'' ) parts.append( f'{tick:.2f}' ) for idx, (run_name, acc) in enumerate(zip(runs, acc_vals, strict=True)): center_x = left_margin + int((idx + 0.5) * slot_w) bar_h = int((acc / vmax) * plot_h) if vmax > 0 else 0 x = center_x - bar_w // 2 y = y_base - bar_h color = palette[idx % len(palette)] parts.append(f'') parts.append( f'{acc:.3f}' ) parts.append( f'{run_name}' ) parts.append("") path.write_text("\n".join(parts), encoding="utf-8") def main() -> None: rank, _, _ = _init_distributed() base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"]))) requested_sweep_root = Path(os.environ["SWEEP_ROOT"]) sweep_root = _resolve_sweep_root(base_cfg, requested_sweep_root) if "OUT_ROOT" in os.environ: out_root = resolve_repo_path(os.environ["OUT_ROOT"]) else: out_root = (sweep_root / "eval_results").resolve() eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200")) eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4")) model_dirs = _discover_model_dirs(sweep_root) resolved_model_dirs = [_resolve_local_model_dir(base_cfg, str(path)) for path in model_dirs] if rank == 0: out_root.mkdir(parents=True, exist_ok=True) (out_root / "outputs").mkdir(parents=True, exist_ok=True) if dist.is_initialized(): dist.barrier() summaries: list[dict] = [] for model_dir in resolved_model_dirs: records = evaluate_one_model( model_dir=model_dir, base_cfg=base_cfg, eval_max_samples=eval_max_samples, batch_size=eval_batch_size, ) if rank == 0: output_jsonl = out_root / "outputs" / f"{model_dir.name}_outputs.jsonl" with output_jsonl.open("w", encoding="utf-8") as handle: for row in records: handle.write(json.dumps(row, ensure_ascii=True) + "\n") summary = _summarize(records, str(model_dir)) summary["outputs_jsonl"] = str(output_jsonl) summaries.append(summary) if dist.is_initialized(): dist.barrier() if rank != 0: return json_path = out_root / "sweep_eval_summary.json" csv_path = out_root / "sweep_eval_summary.csv" svg_path = out_root / "sweep_eval_accuracy.svg" json_path.write_text(json.dumps(summaries, indent=2), encoding="utf-8") with csv_path.open("w", encoding="utf-8", newline="") as handle: writer = csv.DictWriter( handle, fieldnames=[ "name", "model_dir", "num_examples", "accuracy", "avg_cot_words", "outputs_jsonl", ], ) writer.writeheader() for row in summaries: writer.writerow(row) _write_accuracy_svg(summaries, svg_path) print(f"Saved summary: {json_path}") print(f"Saved summary: {csv_path}") print(f"Saved plot: {svg_path}") print(f"Saved outputs dir: {out_root / 'outputs'}") if __name__ == "__main__": main()