"""
Benchmark Circuit transformer family against standard LM tasks.

Usage:
    # Single model
    python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0

    # Compare all architectures
    python -m circuits.bench --compare --gpu 0

    # Quick sanity check (100 samples per task)
    python -m circuits.bench --compare --gpu 0 --limit 100

    # Specific tasks
    python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai
"""
|
| |
|
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import torch

import lm_eval
from lm_eval.api.registry import register_model

from .lm_eval_wrapper import CircuitLM
|
| |
|
| |
|
# Register the Circuit wrapper with lm-eval under the model name "circuit",
# so simple_evaluate(model="circuit", ...) can resolve it.
register_model("circuit")(CircuitLM)

# Tasks evaluated when --tasks is not given on the command line.
DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande"

# Architecture name -> best-checkpoint path; --compare evaluates every entry
# whose checkpoint file exists on disk.
CHECKPOINTS = {
    "standard_12L": "circuits/checkpoints/flat/best.pt",
    "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt",
    "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt",
    "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt",
}
|
| |
|
| |
|
def run_benchmark(checkpoint: str, tasks: str, device: str, limit: int | None = None, batch_size: int = 1, compile: bool = False):
    """Run lm-eval on a single checkpoint.

    Args:
        checkpoint: Path to the model checkpoint file.
        tasks: Comma-separated lm-eval task names.
        device: Torch device string (e.g. "cuda:0").
        limit: Optional cap on samples per task; None evaluates the full task.
            (Fixed annotation: was `int = None`.)
        batch_size: Evaluation batch size, forwarded to the model wrapper.
        compile: Whether the wrapper should torch.compile the model.
            NOTE: shadows the `compile` builtin; name kept for caller
            compatibility.

    Returns:
        The raw results dict from ``lm_eval.simple_evaluate``.
    """
    # CircuitLM parses these key=value pairs; booleans travel as 'true'/'false'.
    model_args = (
        f"checkpoint={checkpoint},device={device},"
        f"batch_size={batch_size},compile={'true' if compile else 'false'}"
    )

    return lm_eval.simple_evaluate(
        model="circuit",
        model_args=model_args,
        tasks=tasks.split(","),
        limit=limit,
    )
|
| |
|
| |
|
def extract_scores(results: dict) -> dict:
    """Pull headline metrics from lm-eval results.

    For each task, record the first metric present in priority order:
    normalized accuracy, plain accuracy, perplexity, word perplexity.
    Tasks exposing none of these are omitted.
    """
    if "results" not in results:
        return {}

    # Highest-priority metric wins; keys follow lm-eval's "<metric>,none" form.
    metric_priority = (
        "acc_norm,none",
        "acc,none",
        "perplexity,none",
        "word_perplexity,none",
    )

    scores = {}
    for task_name, metrics in results["results"].items():
        for key in metric_priority:
            if key in metrics:
                scores[task_name] = metrics[key]
                break
    return scores
|
| |
|
| |
|
def print_comparison(all_results: dict, tasks: list):
    """Pretty-print a model-by-task comparison table to stdout.

    Args:
        all_results: Mapping of model name -> {task name -> headline score}.
        tasks: Task names to render as columns, in order.

    Scores print with 4 decimal places; missing tasks show "N/A". A trailing
    column holds the mean of the scores present for that model (0 if none).
    NOTE(review): the mean blends accuracy (higher better) with perplexity
    (lower better) when both kinds of task are selected.
    """
    if not all_results or not tasks:
        # max() below raises ValueError on an empty sequence; an empty table
        # (e.g. no checkpoints found in --compare mode) should not crash.
        print("\n(no results to compare)")
        return

    col_width = max(len(t) for t in tasks) + 2
    name_width = max(len(n) for n in all_results) + 2

    header = f"{'Model':<{name_width}}"
    for task in tasks:
        header += f"{task:>{col_width}}"
    header += f"{' avg':>8}"
    print("\n" + "=" * len(header))
    print(header)
    print("-" * len(header))

    for name, scores in all_results.items():
        row = f"{name:<{name_width}}"
        vals = []
        for task in tasks:
            val = scores.get(task)
            if val is not None:
                row += f"{val:>{col_width}.4f}"
                vals.append(val)
            else:
                row += f"{'N/A':>{col_width}}"
        avg = sum(vals) / len(vals) if vals else 0
        row += f"{avg:>8.4f}"
        print(row)

    print("=" * len(header))
|
| |
|
| |
|
def main():
    """CLI entry point: benchmark one checkpoint or compare all known ones."""
    parser = argparse.ArgumentParser(description="Benchmark Circuit transformers")
    parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint")
    parser.add_argument("--compare", action="store_true", help="Compare all known architectures")
    parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}"

    if args.compare:
        _compare_all(args, device)
    elif args.checkpoint:
        _eval_single(args, device)
    else:
        parser.print_help()


def _compare_all(args, device):
    """Evaluate every known checkpoint present on disk and print a comparison table."""
    available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()}
    missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()}
    if missing:
        print(f"Skipping (not found): {', '.join(missing.keys())}")
    if not available:
        # Nothing to evaluate; avoid an empty comparison table downstream.
        print("No checkpoints found; nothing to compare.")
        return

    all_scores = {}
    all_raw = {}
    for name, ckpt_path in available.items():
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f"Checkpoint: {ckpt_path}")
        print(f"{'='*60}")

        t0 = time.time()
        results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile)
        elapsed = time.time() - t0

        scores = extract_scores(results)
        all_scores[name] = scores
        all_raw[name] = results.get("results", {})
        print(f" Completed in {elapsed:.0f}s: {scores}")

    print_comparison(all_scores, args.tasks.split(","))

    if args.output:
        with open(args.output, "w") as f:
            # default=str: lm-eval results may contain non-JSON-native objects.
            json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")


def _eval_single(args, device):
    """Evaluate one checkpoint given via --checkpoint and print its scores."""
    print(f"Evaluating: {args.checkpoint}")
    t0 = time.time()
    results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile)
    elapsed = time.time() - t0

    scores = extract_scores(results)
    print(f"\nResults ({elapsed:.0f}s):")
    for task, score in scores.items():
        print(f" {task}: {score:.4f}")

    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")
|
| |
|
| |
|
# Support both `python -m circuits.bench` and direct script execution.
if __name__ == "__main__":
    main()
|
| |
|