from __future__ import annotations
import csv
import json
import os
import re
from pathlib import Path
import torch
import torch.distributed as dist
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer
import hackable # noqa: F401
from hackable.data_plugins import GSM8KProvider
from hackable.paths import resolve_storage_path, storage_layout
from hackable.reward_plugins import gsm8k_correctness_reward
from hackable.utils import resolve_repo_path
# Captures the reasoning text between <think> and </think> in a completion
# (DOTALL so it spans newlines). The tags are essential: a bare lazy "(.*?)"
# matches the empty string at position 0 of every completion, which would make
# every chain-of-thought word count 0.
# NOTE(review): assumes the policy emits <think>...</think> reasoning blocks —
# confirm against the training prompt format.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
def _load_yaml(path: str) -> dict:
    """Load a YAML config file and return its top-level mapping.

    Args:
        path: Filesystem path of the YAML file (read as UTF-8).

    Returns:
        The parsed mapping. An empty document yields ``{}`` instead of
        ``None`` so callers can call ``.get`` on the result unconditionally.

    Raises:
        ValueError: If the document's top level is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as handle:
        data = yaml.safe_load(handle)
    if data is None:
        # safe_load returns None for an empty file; keep the `-> dict` contract.
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a YAML mapping at the top level of {path}, "
            f"got {type(data).__name__}"
        )
    return data
def _cot_word_len(completion: str) -> int:
    """Count whitespace-separated words captured by the first THINK_RE match.

    Returns 0 when THINK_RE finds no match or the captured text is blank.
    """
    found = THINK_RE.search(completion)
    if found is None:
        return 0
    reasoning = found.group(1).strip()
    if not reasoning:
        return 0
    return len(reasoning.split())
def _model_dtype(cfg: dict):
return torch.bfloat16 if bool(cfg.get("trainer", {}).get("bf16", True)) else torch.float16
def _get_cache_paths(base_cfg: dict) -> tuple[Path, Path]:
    """Return ``(datasets_dir, models_dir)`` from the project storage layout.

    The layout is rooted at ``base_cfg["storage"]["cache_dir"]``, falling back
    to ``"cache"`` when either key is missing.
    """
    storage_cfg = base_cfg.get("storage", {})
    cache_dir = storage_cfg.get("cache_dir", "cache")
    layout = storage_layout(cache_dir)
    datasets_dir, models_dir = layout.datasets, layout.models
    return datasets_dir, models_dir
def _dist_info() -> tuple[int, int, int]:
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
return rank, world_size, local_rank
def _init_distributed() -> tuple[int, int, int]:
    """Initialise torch.distributed when launched with WORLD_SIZE > 1.

    Uses NCCL when CUDA is available, else gloo, with ``env://`` rendezvous.
    On CUDA the process is pinned to ``cuda:LOCAL_RANK`` *before* the process
    group is created. Otherwise NCCL collectives issued before any tensor
    touches the GPU (e.g. an early ``dist.barrier()``) fall back to a device
    guessed from the rank, which PyTorch warns can be wrong and hang.

    Returns:
        ``(rank, world_size, local_rank)`` as read by ``_dist_info``.
    """
    rank, world_size, local_rank = _dist_info()
    if world_size > 1 and not dist.is_initialized():
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
            backend = "nccl"
        else:
            backend = "gloo"
        dist.init_process_group(backend=backend, init_method="env://")
    return rank, world_size, local_rank
def _resolve_local_model_dir(base_cfg: dict, model_dir: str) -> Path:
candidate = Path(model_dir)
if candidate.is_absolute() and candidate.exists():
return candidate.resolve()
if not candidate.is_absolute() and candidate.exists():
return candidate.resolve()
repo_local = resolve_repo_path(model_dir)
if repo_local.exists():
return repo_local
cache_root = resolve_repo_path(base_cfg.get("storage", {}).get("cache_dir", "cache"))
prefixed = (cache_root / candidate).resolve()
if prefixed.exists():
return prefixed
raise FileNotFoundError(
f"Model directory not found locally: '{model_dir}'. "
f"Tried '{candidate}', '{repo_local}', and '{prefixed}'."
)
def _resolve_sweep_root(base_cfg: dict, requested_sweep_root: Path) -> Path:
    """Map SWEEP_ROOT onto storage and check that it holds run directories.

    Raises:
        FileNotFoundError: If the resolved path is not a directory, or has no
            subdirectory whose name starts with ``run_``.
    """
    cache_dir = base_cfg.get("storage", {}).get("cache_dir", "cache")
    resolved = resolve_storage_path(requested_sweep_root, cache_dir)

    def _is_run_dir(entry: Path) -> bool:
        return entry.is_dir() and entry.name.startswith("run_")

    if not (resolved.is_dir() and any(map(_is_run_dir, resolved.iterdir()))):
        raise FileNotFoundError(
            "Could not resolve SWEEP_ROOT with run directories: "
            f"{resolved}"
        )
    return resolved
def _discover_model_dirs(sweep_root: Path) -> list[Path]:
dirs = [
path
for path in sweep_root.iterdir()
if path.is_dir() and path.name.startswith("run_")
]
if not dirs:
raise FileNotFoundError(
f"No run directories starting with 'run_' found in {sweep_root}"
)
return sorted(dirs)
def _load_eval_tokenizer(
    model_dir: Path,
    model_name_fallback: str,
    trust_remote_code: bool,
    models_cache: Path,
):
    """Load the tokenizer saved with ``model_dir``, else the base model's one.

    Both attempts pass ``local_files_only=True``, so nothing is downloaded.
    """
    try:
        return AutoTokenizer.from_pretrained(
            str(model_dir),
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )
    except Exception:
        # Deliberate best-effort fallback: presumably the checkpoint dir was
        # saved without tokenizer files, so use the base model's tokenizer
        # from the local cache instead.
        return AutoTokenizer.from_pretrained(
            model_name_fallback,
            trust_remote_code=trust_remote_code,
            cache_dir=str(models_cache),
            local_files_only=True,
        )


@torch.no_grad()
def _generate_greedy_completions(
    model,
    tokenizer,
    prompts: list[str],
    batch_size: int,
    max_prompt_len: int,
    max_completion_len: int,
    device: torch.device,
) -> list[str]:
    """Greedy-decode one completion per prompt, returned in prompt order.

    Expects ``tokenizer`` to left-pad and to have a pad token set. For a causal
    LM, ``generate`` returns each row as the whole padded prompt block followed
    by the new tokens. The completion is therefore every column after
    ``input_ids.shape[1]``, for every row, regardless of its padding.
    """
    completions: list[str] = []
    for start in range(0, len(prompts), batch_size):
        enc = tokenizer(
            prompts[start : start + batch_size],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_prompt_len,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn,
            max_new_tokens=max_completion_len,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Slicing by the per-row count of real prompt tokens (the previous
        # approach) leaks prompt-tail tokens into the completion whenever the
        # batch is left-padded.
        new_tokens = out[:, input_ids.shape[1] :]
        completions.extend(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))
    return completions


def _gather_records(local_records: list[dict]) -> list[dict]:
    """Merge every rank's records and return them sorted by ``sample_index``.

    When torch.distributed is not initialised, only this rank's records are
    returned (still sorted).
    """
    if dist.is_initialized():
        gathered: list[list[dict] | None] = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, local_records)
        merged = [row for part in gathered if part for row in part]
    else:
        merged = list(local_records)
    merged.sort(key=lambda row: row["sample_index"])
    return merged


@torch.no_grad()
def evaluate_one_model(
    model_dir: Path,
    base_cfg: dict,
    eval_max_samples: int,
    batch_size: int,
) -> list[dict]:
    """Greedy-decode the GSM8K test split with one checkpoint and score it.

    The test samples are sharded round-robin across ranks (sample ``i`` is
    handled by rank ``i % world_size``). Each rank generates completions for
    its shard and scores them with ``gsm8k_correctness_reward``. When
    torch.distributed is initialised, the per-rank records are all-gathered,
    so every rank returns the full set.

    Args:
        model_dir: Local checkpoint directory passed to ``from_pretrained``.
        base_cfg: Parsed base config. Reads ``generation.max_prompt_length``
            (default 512), ``generation.max_completion_length`` (default 256),
            ``model.name`` (tokenizer fallback) and
            ``model.trust_remote_code`` (default False).
        eval_max_samples: Sample cap forwarded to the provider; a negative
            value passes ``max_samples=None``.
        batch_size: Prompts per ``generate`` call.

    Returns:
        One dict per evaluated sample, sorted by ``sample_index``, with keys
        ``sample_index``, ``prompt``, ``reference``, ``completion``,
        ``correctness`` and ``cot_words``.
    """
    rank, world_size, local_rank = _dist_info()
    generation = base_cfg.get("generation", {})
    max_prompt_len = int(generation.get("max_prompt_length", 512))
    max_completion_len = int(generation.get("max_completion_length", 256))
    model_name_fallback = str(base_cfg["model"]["name"])
    trust_remote_code = bool(base_cfg.get("model", {}).get("trust_remote_code", False))
    dtype = _model_dtype(base_cfg)
    datasets_cache, models_cache = _get_cache_paths(base_cfg)

    all_samples = GSM8KProvider().load(
        split="test",
        max_samples=None if eval_max_samples < 0 else eval_max_samples,
        cache_dir=str(datasets_cache),
    )
    # Round-robin shard: rank r takes samples r, r + world_size, r + 2*world_size, ...
    indices = list(range(rank, len(all_samples), world_size))
    local_samples = [all_samples[idx] for idx in indices]
    prompts = [sample.prompt for sample in local_samples]
    refs = [sample.target for sample in local_samples]
    metadata = [sample.metadata for sample in local_samples]

    tokenizer = _load_eval_tokenizer(
        model_dir, model_name_fallback, trust_remote_code, models_cache
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models continue from each row's last position. Right
    # padding would place pad tokens between the prompt and the generated
    # text, so batched generation must left-pad.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        str(model_dir),
        trust_remote_code=trust_remote_code,
        cache_dir=str(models_cache),
        torch_dtype=dtype,
        local_files_only=True,
    )
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")
    model.to(device)
    model.eval()

    completions = _generate_greedy_completions(
        model,
        tokenizer,
        prompts,
        batch_size,
        max_prompt_len,
        max_completion_len,
        device,
    )

    # Release the model's GPU memory as soon as generation is finished.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    scores = gsm8k_correctness_reward(
        prompts=prompts,
        completions=completions,
        references=refs,
        metadata=metadata,
    )
    local_records = [
        {
            "sample_index": int(sample_index),
            "prompt": prompt,
            "reference": reference,
            "completion": completion,
            "correctness": float(score),
            "cot_words": int(_cot_word_len(completion)),
        }
        for sample_index, prompt, reference, completion, score in zip(
            indices, prompts, refs, completions, scores, strict=True
        )
    ]
    return _gather_records(local_records)
def _summarize(records: list[dict], model_dir: str) -> dict:
if not records:
return {
"name": Path(model_dir).name,
"model_dir": model_dir,
"num_examples": 0,
"accuracy": 0.0,
"avg_cot_words": 0.0,
}
accuracy = sum(float(row["correctness"]) for row in records) / len(records)
avg_cot = sum(float(row["cot_words"]) for row in records) / len(records)
return {
"name": Path(model_dir).name,
"model_dir": model_dir,
"num_examples": len(records),
"accuracy": float(accuracy),
"avg_cot_words": float(avg_cot),
}
def _write_accuracy_svg(summaries: list[dict], path: Path) -> None:
width = 1000
height = 460
left_margin = 70
right_margin = 30
top_margin = 70
bottom_margin = 90
plot_w = width - left_margin - right_margin
plot_h = height - top_margin - bottom_margin
y_base = top_margin + plot_h
runs = [row["name"] for row in summaries]
acc_vals = [float(row["accuracy"]) for row in summaries]
vmax = max(1.0, max(acc_vals) if acc_vals else 1.0)
bar_count = max(1, len(runs))
slot_w = plot_w / bar_count
bar_w = min(120, max(30, int(slot_w * 0.55)))
palette = ["#2563eb", "#dc2626", "#16a34a", "#ca8a04", "#7c3aed", "#0891b2"]
parts: list[str] = []
parts.append(f'")
path.write_text("\n".join(parts), encoding="utf-8")
def _write_summary_files(out_root: Path, summaries: list[dict]) -> None:
    """Write the sweep-level JSON summary, CSV summary and accuracy SVG.

    Files are written into ``out_root``, and each artefact's path is printed.
    """
    json_path = out_root / "sweep_eval_summary.json"
    csv_path = out_root / "sweep_eval_summary.csv"
    svg_path = out_root / "sweep_eval_accuracy.svg"
    json_path.write_text(json.dumps(summaries, indent=2), encoding="utf-8")
    with csv_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(
            handle,
            fieldnames=[
                "name",
                "model_dir",
                "num_examples",
                "accuracy",
                "avg_cot_words",
                "outputs_jsonl",
            ],
        )
        writer.writeheader()
        writer.writerows(summaries)
    _write_accuracy_svg(summaries, svg_path)
    print(f"Saved summary: {json_path}")
    print(f"Saved summary: {csv_path}")
    print(f"Saved plot: {svg_path}")
    print(f"Saved outputs dir: {out_root / 'outputs'}")


def main() -> None:
    """Evaluate every ``run_*`` checkpoint under SWEEP_ROOT on GSM8K.

    Environment:
        BASE_CONFIG: Path of the base YAML config (required).
        SWEEP_ROOT: Sweep directory containing ``run_*`` subdirs (required).
        OUT_ROOT: Output directory (default ``<sweep_root>/eval_results``).
        EVAL_MAX_SAMPLES: Sample cap per model (default 200; a negative value
            passes ``max_samples=None`` to the dataset provider).
        EVAL_BATCH_SIZE: Prompts per ``generate`` call (default 4, must be >= 1).

    All ranks run generation. Only rank 0 writes the per-model JSONL outputs
    and the sweep-level summaries. The process group, if any, is destroyed on
    exit.

    Raises:
        ValueError: If EVAL_BATCH_SIZE is less than 1.
    """
    rank, _, _ = _init_distributed()
    try:
        base_cfg = _load_yaml(str(resolve_repo_path(os.environ["BASE_CONFIG"])))
        sweep_root = _resolve_sweep_root(base_cfg, Path(os.environ["SWEEP_ROOT"]))
        if "OUT_ROOT" in os.environ:
            out_root = resolve_repo_path(os.environ["OUT_ROOT"])
        else:
            out_root = (sweep_root / "eval_results").resolve()
        eval_max_samples = int(os.environ.get("EVAL_MAX_SAMPLES", "200"))
        eval_batch_size = int(os.environ.get("EVAL_BATCH_SIZE", "4"))
        if eval_batch_size < 1:
            # 0 would crash range() deep inside generation; a negative value
            # would silently produce no completions.
            raise ValueError(f"EVAL_BATCH_SIZE must be >= 1, got {eval_batch_size}")

        model_dirs = [
            _resolve_local_model_dir(base_cfg, str(path))
            for path in _discover_model_dirs(sweep_root)
        ]

        outputs_dir = out_root / "outputs"
        if rank == 0:
            # parents=True also creates out_root itself.
            outputs_dir.mkdir(parents=True, exist_ok=True)
        if dist.is_initialized():
            dist.barrier()

        summaries: list[dict] = []
        for model_dir in model_dirs:
            records = evaluate_one_model(
                model_dir=model_dir,
                base_cfg=base_cfg,
                eval_max_samples=eval_max_samples,
                batch_size=eval_batch_size,
            )
            if rank == 0:
                output_jsonl = outputs_dir / f"{model_dir.name}_outputs.jsonl"
                with output_jsonl.open("w", encoding="utf-8") as handle:
                    for row in records:
                        handle.write(json.dumps(row, ensure_ascii=True) + "\n")
                summary = _summarize(records, str(model_dir))
                summary["outputs_jsonl"] = str(output_jsonl)
                summaries.append(summary)
            # Keep ranks in lockstep between checkpoints.
            if dist.is_initialized():
                dist.barrier()

        if rank == 0:
            _write_summary_files(out_root, summaries)
    finally:
        # Release the process group on every rank, including non-zero ranks
        # that skip the summary step.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()