#!/usr/bin/env python3
"""
Benchmark Circuit transformer family against standard LM tasks.
Usage:
# Single model
python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0
# Compare all architectures
python -m circuits.bench --compare --gpu 0
# Quick sanity check (100 samples per task)
python -m circuits.bench --compare --gpu 0 --limit 100
# Specific tasks
python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai
"""
import argparse
import json
import time
from pathlib import Path
from typing import Optional

import lm_eval
import torch
from lm_eval.api.registry import register_model

from .lm_eval_wrapper import CircuitLM

# Register so lm_eval can find it
register_model("circuit")(CircuitLM)
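# Registering "circuit" is what lets lm_eval.simple_evaluate(model="circuit", ...)
# below resolve to CircuitLM; lm_eval parses the model_args string into
# constructor kwargs for the wrapper.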
DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande"

# Known checkpoints for --compare mode
CHECKPOINTS = {
    "standard_12L": "circuits/checkpoints/flat/best.pt",
    "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt",
    "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt",
    "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt",
}
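# Paths are relative, so --compare is expected to run from the repository root;
# checkpoints that are not found are skipped with a warning in main().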

def run_benchmark(
    checkpoint: str,
    tasks: str,
    device: str,
    limit: Optional[int] = None,
    batch_size: int = 1,
    compile: bool = False,
):
    """Run lm-eval on a single checkpoint and return the raw results dict."""
    model_args = (
        f"checkpoint={checkpoint},device={device},"
        f"batch_size={batch_size},compile={'true' if compile else 'false'}"
    )
    task_list = tasks.split(",")
    results = lm_eval.simple_evaluate(
        model="circuit",
        model_args=model_args,
        tasks=task_list,
        limit=limit,
    )
    return results
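# Example of calling run_benchmark() programmatically (illustrative arguments,
# using one of the CHECKPOINTS paths above; extract_scores is defined below):
#   results = run_benchmark("circuits/checkpoints/flat/best.pt",
#                           "hellaswag,piqa", "cuda:0", limit=100)
#   print(extract_scores(results))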

def extract_scores(results: dict) -> dict:
    """Pull headline metrics from lm-eval results."""
    scores = {}
    if "results" not in results:
        return scores
    for task_name, task_results in results["results"].items():
        # Get the primary metric (usually acc or acc_norm)
        if "acc_norm,none" in task_results:
            scores[task_name] = task_results["acc_norm,none"]
        elif "acc,none" in task_results:
            scores[task_name] = task_results["acc,none"]
        elif "perplexity,none" in task_results:
            scores[task_name] = task_results["perplexity,none"]
        elif "word_perplexity,none" in task_results:
            scores[task_name] = task_results["word_perplexity,none"]
    return scores
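# The ",none" suffixes above follow lm-eval's "<metric>,<filter>" result keys
# (the default filter is named "none"), as used by recent (0.4+) releases.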

def print_comparison(all_results: dict, tasks: list):
    """Pretty-print a model-by-task comparison table."""
    if not all_results:
        print("No results to compare.")
        return
    # Column widths: leave room for task names and for "0.1234"-style values
    col_width = max(max(len(t) for t in tasks) + 2, 8)
    name_width = max(len(n) for n in all_results) + 2
    header = f"{'Model':<{name_width}}"
    for task in tasks:
        header += f"{task:>{col_width}}"
    header += f"{' avg':>8}"
    print("\n" + "=" * len(header))
    print(header)
    print("-" * len(header))
    for name, scores in all_results.items():
        row = f"{name:<{name_width}}"
        vals = []
        for task in tasks:
            val = scores.get(task, None)
            if val is not None:
                row += f"{val:>{col_width}.4f}"
                vals.append(val)
            else:
                row += f"{'N/A':>{col_width}}"
        avg = sum(vals) / len(vals) if vals else 0.0
        row += f"{avg:>8.4f}"
        print(row)
    print("=" * len(header))

def main():
    parser = argparse.ArgumentParser(description="Benchmark Circuit transformers")
    parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint")
    parser.add_argument("--compare", action="store_true", help="Compare all known architectures")
    parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}"
    task_list = args.tasks.split(",")

    if args.compare:
        all_scores = {}
        all_raw = {}
        # Filter to existing checkpoints
        available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()}
        missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()}
        if missing:
            print(f"Skipping (not found): {', '.join(missing.keys())}")
        for name, ckpt_path in available.items():
            print(f"\n{'='*60}")
            print(f"Evaluating: {name}")
            print(f"Checkpoint: {ckpt_path}")
            print(f"{'='*60}")
            t0 = time.time()
            results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile)
            elapsed = time.time() - t0
            scores = extract_scores(results)
            all_scores[name] = scores
            all_raw[name] = results.get("results", {})
            print(f" Completed in {elapsed:.0f}s: {scores}")
        print_comparison(all_scores, task_list)
        if args.output:
            with open(args.output, "w") as f:
                json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str)
            print(f"\nResults saved to {args.output}")
    elif args.checkpoint:
        print(f"Evaluating: {args.checkpoint}")
        t0 = time.time()
        results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile)
        elapsed = time.time() - t0
        scores = extract_scores(results)
        print(f"\nResults ({elapsed:.0f}s):")
        for task, score in scores.items():
            print(f" {task}: {score:.4f}")
        if args.output:
            with open(args.output, "w") as f:
                json.dump(results, f, indent=2, default=str)
            print(f"\nResults saved to {args.output}")
    else:
        parser.print_help()


if __name__ == "__main__":
    main()