# cortex/benchmark/runner.py
# Commit 0de2901 (theapemachine): Enhance benchmark and Cortex modules with new
# training utilities and improved state management. Update README with example
# output for Llama-3.2-1B and add training CLI for Cortex module tuning.
# Refactor scoring functions to reset Cortex state between examples and ensure
# consistent output. Modify task handling to ensure proper formatting of input
# data.
"""
Benchmark Runner: Orchestrates base vs Cortex-enhanced model comparison.
Usage:
runner = BenchmarkRunner(model_name="HuggingFaceTB/SmolLM2-135M")
results = runner.run_comparison(tasks=["hellaswag", "piqa"], n=50)
runner.print_summary(results)
"""
import sys
import os
import time
import json
import torch
from typing import Dict, List, Optional, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
# Add parent dir so cortex can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from benchmark.scoring import log_likelihood_score, accuracy_from_loglikelihoods
from benchmark.tasks import TASK_REGISTRY, BenchmarkTask
from benchmark.memory_tasks import PasskeyRetrieval, MultiHopMemory
from cortex.torch_device import resolve_torch_device
class BenchmarkRunner:
"""
Runs a full comparison between base model and Cortex-enhanced model.
Workflow:
1. Load base model, run all tasks → base results
2. Inject Cortex modules via CortexSurgeon → enhanced model
3. Run all tasks again → cortex results
4. Compare and report
"""
def __init__(
self,
model_name: str = "HuggingFaceTB/SmolLM2-135M",
device: str = "auto",
dtype: str = "float32",
cortex_weights: Optional[str] = None,
):
self.model_name = model_name
self.cortex_weights = cortex_weights
if device == "auto":
self.device = resolve_torch_device("auto")
else:
self.device = device
self.dtype = getattr(torch, dtype)
print(f"Loading model: {model_name} on {self.device} ({dtype})")
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=self.dtype,
device_map=self.device,
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model.eval()
print(f"Model loaded: {self.model.config.hidden_size}d, {self.model.config.num_hidden_layers}L")
def _run_loglikelihood_task(
self,
task: BenchmarkTask,
model,
n: Optional[int] = None,
) -> Dict:
"""Run a log-likelihood scoring task."""
print(f" Loading examples for {task.name}...")
examples = task.load_examples(n=n)
print(f" Scoring {len(examples)} examples...")
scores_and_golds = []
for i, ex in enumerate(examples):
if (i + 1) % 10 == 0:
print(f" [{i+1}/{len(examples)}]")
scores = log_likelihood_score(
model, self.tokenizer,
ex["context"], ex["continuations"],
device=self.device,
)
scores_and_golds.append((scores, ex["gold_idx"]))
return accuracy_from_loglikelihoods(scores_and_golds)
def _run_memory_tasks(
self,
model,
n_passkey: int = 5,
passkey_lengths: Optional[List[int]] = None,
n_multihop: Optional[int] = None,
) -> Dict:
"""Run memory-specific benchmarks."""
results = {}
# Passkey retrieval
print(" Running passkey retrieval...")
passkey = PasskeyRetrieval(context_lengths=passkey_lengths or [128, 256, 512])
results["passkey_retrieval"] = passkey.run(
model, self.tokenizer,
n_per_length=n_passkey, device=self.device,
)
# Multi-hop memory
print(" Running multi-hop memory...")
multihop = MultiHopMemory()
results["multi_hop_memory"] = multihop.run(
model, self.tokenizer,
n=n_multihop, device=self.device,
)
return results
def inject_cortex(self) -> Dict:
"""
Inject all Cortex modules into the model.
Returns dict with module info.
"""
from cortex import (
CortexSurgeon, MemoryBank, HallucinationGate,
PauseAndThink, BacktrackHead, SteeringVector, AdaptiveDepth,
)
surgeon = CortexSurgeon(self.model)
hidden_dim = surgeon.hidden_dim
num_layers = surgeon.num_layers
# Find valid num_heads for cross-attention
num_heads = 8
while hidden_dim % num_heads != 0 and num_heads > 1:
num_heads -= 1
middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
deep_layers = list(range(2 * num_layers // 3, num_layers))
surgeon.add_module("memory", MemoryBank(
hidden_dim=hidden_dim, num_slots=32, num_heads=num_heads,
target_layers=middle_layers,
))
surgeon.add_module("halluc_gate", HallucinationGate(
hidden_dim=hidden_dim, bottleneck_dim=32,
target_layers=deep_layers,
))
surgeon.add_module("pause_think", PauseAndThink(
hidden_dim=hidden_dim, num_think_tokens=4,
target_layers=middle_layers,
))
surgeon.add_module("backtrack", BacktrackHead(
hidden_dim=hidden_dim, confidence_bottleneck=32,
num_layers=num_layers, target_layers="all",
))
surgeon.add_module("steering", SteeringVector(
hidden_dim=hidden_dim, num_directions=2,
direction_names=["truthfulness", "helpfulness"],
target_layers=middle_layers,
))
surgeon.add_module("adaptive_depth", AdaptiveDepth(
hidden_dim=hidden_dim, target_layers="all",
))
surgeon.operate(freeze_base=True)
if self.cortex_weights:
surgeon.load_cortex_modules(self.cortex_weights)
print(f" Loaded Cortex weights: {self.cortex_weights}")
report = surgeon.get_parameter_report()
total_cortex = sum(info["trainable"] for info in report.values())
total_model = sum(p.numel() for p in self.model.parameters())
self._surgeon = surgeon
return {
"total_cortex_params": total_cortex,
"total_model_params": total_model,
"overhead_pct": total_cortex / total_model * 100,
"per_module": report,
}
def remove_cortex(self):
"""Remove Cortex modules and restore base model."""
if hasattr(self, "_surgeon"):
self._surgeon.remove_all()
del self._surgeon
def run_comparison(
self,
tasks: Optional[List[str]] = None,
n: int = 50,
include_memory: bool = True,
n_passkey: int = 5,
passkey_lengths: Optional[List[int]] = None,
) -> Dict:
"""
Run full comparison: base model vs Cortex-enhanced.
Args:
tasks: List of task names from TASK_REGISTRY. None = all.
n: Number of examples per task.
include_memory: Whether to run memory benchmarks.
n_passkey: Number of passkey examples per context length.
passkey_lengths: Context lengths for passkey test.
Returns:
Dict with base_results, cortex_results, and comparison.
"""
if tasks is None:
tasks = ["hellaswag", "piqa", "arc-easy", "winogrande"]
results = {
"model": self.model_name,
"device": self.device,
"dtype": str(self.dtype),
"n_per_task": n,
"tasks": tasks,
"base": {},
"cortex": {},
"comparison": {},
}
# ===== BASE MODEL =====
print("\n" + "=" * 60)
print("PHASE 1: BASE MODEL EVALUATION")
print("=" * 60)
for task_name in tasks:
print(f"\n[BASE] Running {task_name}...")
t0 = time.time()
task_cls = TASK_REGISTRY[task_name]
task = task_cls() if callable(task_cls) else task_cls
result = self._run_loglikelihood_task(task, self.model, n=n)
result["time_seconds"] = time.time() - t0
results["base"][task_name] = result
print(f" {task_name}: {result['accuracy']:.4f} ({result['correct']}/{result['total']}) "
f"[{result['time_seconds']:.1f}s]")
if include_memory:
print(f"\n[BASE] Running memory benchmarks...")
t0 = time.time()
mem_results = self._run_memory_tasks(
self.model, n_passkey=n_passkey,
passkey_lengths=passkey_lengths,
)
mem_results["time_seconds"] = time.time() - t0
results["base"]["memory"] = mem_results
pk = mem_results["passkey_retrieval"]["overall"]
mh = mem_results["multi_hop_memory"]
print(f" passkey: {pk['accuracy']:.4f} ({pk['correct']}/{pk['total']})")
print(f" multi_hop: {mh['accuracy']:.4f} ({mh['correct']}/{mh['total']})")
# ===== CORTEX-ENHANCED MODEL =====
print("\n" + "=" * 60)
print("PHASE 2: CORTEX-ENHANCED MODEL EVALUATION")
print("=" * 60)
print("\nInjecting Cortex modules...")
module_info = self.inject_cortex()
print(f" Cortex params: {module_info['total_cortex_params']:,} "
f"({module_info['overhead_pct']:.2f}% overhead)")
results["cortex_info"] = module_info
for task_name in tasks:
print(f"\n[CORTEX] Running {task_name}...")
t0 = time.time()
task_cls = TASK_REGISTRY[task_name]
task = task_cls() if callable(task_cls) else task_cls
result = self._run_loglikelihood_task(task, self.model, n=n)
result["time_seconds"] = time.time() - t0
results["cortex"][task_name] = result
print(f" {task_name}: {result['accuracy']:.4f} ({result['correct']}/{result['total']}) "
f"[{result['time_seconds']:.1f}s]")
if include_memory:
print(f"\n[CORTEX] Running memory benchmarks...")
t0 = time.time()
mem_results = self._run_memory_tasks(
self.model, n_passkey=n_passkey,
passkey_lengths=passkey_lengths,
)
mem_results["time_seconds"] = time.time() - t0
results["cortex"]["memory"] = mem_results
pk = mem_results["passkey_retrieval"]["overall"]
mh = mem_results["multi_hop_memory"]
print(f" passkey: {pk['accuracy']:.4f} ({pk['correct']}/{pk['total']})")
print(f" multi_hop: {mh['accuracy']:.4f} ({mh['correct']}/{mh['total']})")
# ===== COMPARISON =====
print("\n" + "=" * 60)
print("COMPARISON: BASE vs CORTEX")
print("=" * 60)
for task_name in tasks:
base_acc = results["base"][task_name]["accuracy"]
cortex_acc = results["cortex"][task_name]["accuracy"]
delta = cortex_acc - base_acc
symbol = "↑" if delta > 0 else "↓" if delta < 0 else "="
results["comparison"][task_name] = {
"base": base_acc,
"cortex": cortex_acc,
"delta": delta,
}
print(f" {task_name:20s} base={base_acc:.4f} cortex={cortex_acc:.4f} "
f"Δ={delta:+.4f} {symbol}")
if include_memory:
base_pk = results["base"]["memory"]["passkey_retrieval"]["overall"]["accuracy"]
cortex_pk = results["cortex"]["memory"]["passkey_retrieval"]["overall"]["accuracy"]
base_mh = results["base"]["memory"]["multi_hop_memory"]["accuracy"]
cortex_mh = results["cortex"]["memory"]["multi_hop_memory"]["accuracy"]
results["comparison"]["passkey"] = {
"base": base_pk, "cortex": cortex_pk, "delta": cortex_pk - base_pk,
}
results["comparison"]["multi_hop"] = {
"base": base_mh, "cortex": cortex_mh, "delta": cortex_mh - base_mh,
}
print(f" {'passkey':20s} base={base_pk:.4f} cortex={cortex_pk:.4f} "
f"Δ={cortex_pk - base_pk:+.4f}")
print(f" {'multi_hop':20s} base={base_mh:.4f} cortex={cortex_mh:.4f} "
f"Δ={cortex_mh - base_mh:+.4f}")
# Remove cortex modules to restore base model
self.remove_cortex()
return results
@staticmethod
def print_summary(results: Dict):
"""Print a formatted summary of benchmark results."""
print("\n" + "=" * 70)
print(f"BENCHMARK SUMMARY: {results['model']}")
print(f"n={results['n_per_task']} per task, device={results['device']}")
print("=" * 70)
print(f"\n{'Task':22s} {'Base':>8s} {'Cortex':>8s} {'Delta':>8s}")
print("-" * 50)
for task_name, comp in results["comparison"].items():
delta_str = f"{comp['delta']:+.4f}"
symbol = " ↑" if comp["delta"] > 0.001 else " ↓" if comp["delta"] < -0.001 else " "
print(f"{task_name:22s} {comp['base']:8.4f} {comp['cortex']:8.4f} {delta_str:>8s}{symbol}")
if "cortex_info" in results:
info = results["cortex_info"]
print(f"\nCortex overhead: {info['total_cortex_params']:,} params "
f"({info['overhead_pct']:.2f}%)")
print("=" * 70)