#!/usr/bin/env python3 """ Benchmark Circuit transformer family against standard LM tasks. Usage: # Single model python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0 # Compare all architectures python -m circuits.bench --compare --gpu 0 # Quick sanity check (100 samples per task) python -m circuits.bench --compare --gpu 0 --limit 100 # Specific tasks python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai """ import argparse import json import time import torch from pathlib import Path import lm_eval from lm_eval.api.registry import register_model from .lm_eval_wrapper import CircuitLM # Register so lm_eval can find it register_model("circuit")(CircuitLM) DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande" # Known checkpoints for --compare mode CHECKPOINTS = { "standard_12L": "circuits/checkpoints/flat/best.pt", "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt", "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt", "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt", } def run_benchmark(checkpoint: str, tasks: str, device: str, limit: int = None, batch_size: int = 1, compile: bool = False): """Run lm-eval on a single checkpoint.""" model_args = f"checkpoint={checkpoint},device={device},batch_size={batch_size},compile={'true' if compile else 'false'}" task_list = tasks.split(",") results = lm_eval.simple_evaluate( model="circuit", model_args=model_args, tasks=task_list, limit=limit, ) return results def extract_scores(results: dict) -> dict: """Pull headline metrics from lm-eval results.""" scores = {} if "results" not in results: return scores for task_name, task_results in results["results"].items(): # Get the primary metric (usually acc or acc_norm) if "acc_norm,none" in task_results: scores[task_name] = task_results["acc_norm,none"] elif "acc,none" in task_results: scores[task_name] = task_results["acc,none"] elif "perplexity,none" in task_results: scores[task_name] = task_results["perplexity,none"] elif "word_perplexity,none" in task_results: scores[task_name] = task_results["word_perplexity,none"] return scores def print_comparison(all_results: dict, tasks: list): """Pretty-print comparison table.""" # Header col_width = max(len(t) for t in tasks) + 2 name_width = max(len(n) for n in all_results) + 2 header = f"{'Model':<{name_width}}" for task in tasks: header += f"{task:>{col_width}}" header += f"{' avg':>8}" print("\n" + "=" * len(header)) print(header) print("-" * len(header)) for name, scores in all_results.items(): row = f"{name:<{name_width}}" vals = [] for task in tasks: val = scores.get(task, None) if val is not None: row += f"{val:>{col_width}.4f}" vals.append(val) else: row += f"{'N/A':>{col_width}}" avg = sum(vals) / len(vals) if vals else 0 row += f"{avg:>8.4f}" print(row) print("=" * len(header)) def main(): parser = argparse.ArgumentParser(description="Benchmark Circuit transformers") parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint") parser.add_argument("--compare", action="store_true", help="Compare all known architectures") parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list") parser.add_argument("--gpu", type=int, default=0, help="GPU index") parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)") parser.add_argument("--batch-size", type=int, default=1, help="Batch size") parser.add_argument("--output", type=str, default=None, help="Save results to JSON") parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference") args = parser.parse_args() device = f"cuda:{args.gpu}" task_list = args.tasks.split(",") if args.compare: all_scores = {} all_raw = {} # Filter to existing checkpoints available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()} missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()} if missing: print(f"Skipping (not found): {', '.join(missing.keys())}") for name, ckpt_path in available.items(): print(f"\n{'='*60}") print(f"Evaluating: {name}") print(f"Checkpoint: {ckpt_path}") print(f"{'='*60}") t0 = time.time() results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile) elapsed = time.time() - t0 scores = extract_scores(results) all_scores[name] = scores all_raw[name] = results.get("results", {}) print(f" Completed in {elapsed:.0f}s: {scores}") print_comparison(all_scores, task_list) if args.output: with open(args.output, "w") as f: json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str) print(f"\nResults saved to {args.output}") elif args.checkpoint: print(f"Evaluating: {args.checkpoint}") t0 = time.time() results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile) elapsed = time.time() - t0 scores = extract_scores(results) print(f"\nResults ({elapsed:.0f}s):") for task, score in scores.items(): print(f" {task}: {score:.4f}") if args.output: with open(args.output, "w") as f: json.dump(results, f, indent=2, default=str) print(f"\nResults saved to {args.output}") else: parser.print_help() if __name__ == "__main__": main()