# Commit 0de2901: Enhance benchmark and Cortex modules with new training
# utilities and improved state management. Update README with example output
# for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor
# scoring functions to reset Cortex state between examples and ensure
# consistent output. Modify task handling to ensure proper formatting of
# input data.
"""
Benchmark Runner: Orchestrates base vs Cortex-enhanced model comparison.

Usage:
    runner = BenchmarkRunner(model_name="HuggingFaceTB/SmolLM2-135M")
    results = runner.run_comparison(tasks=["hellaswag", "piqa"], n=50)
    print_summary(results)
"""
import sys
import os
import time
import json
import torch
from typing import Dict, List, Optional, Any
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add parent dir so cortex can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from benchmark.scoring import log_likelihood_score, accuracy_from_loglikelihoods
from benchmark.tasks import TASK_REGISTRY, BenchmarkTask
from benchmark.memory_tasks import PasskeyRetrieval, MultiHopMemory
from cortex.torch_device import resolve_torch_device
class BenchmarkRunner:
    """
    Runs a full comparison between base model and Cortex-enhanced model.

    Workflow:
        1. Load base model, run all tasks -> base results
        2. Inject Cortex modules via CortexSurgeon -> enhanced model
        3. Run all tasks again -> cortex results
        4. Compare and report
    """

    def __init__(
        self,
        model_name: str = "HuggingFaceTB/SmolLM2-135M",
        device: str = "auto",
        dtype: str = "float32",
        cortex_weights: Optional[str] = None,
    ):
        """
        Load the base model and tokenizer.

        Args:
            model_name: HuggingFace model id to load.
            device: Torch device string, or "auto" to resolve automatically.
            dtype: Name of a torch dtype attribute (e.g. "float32").
            cortex_weights: Optional path to pretrained Cortex module weights
                loaded after injection in `inject_cortex`.
        """
        self.model_name = model_name
        self.cortex_weights = cortex_weights
        # "auto" delegates device selection to the shared resolver.
        if device == "auto":
            self.device = resolve_torch_device("auto")
        else:
            self.device = device
        self.dtype = getattr(torch, dtype)
        print(f"Loading model: {model_name} on {self.device} ({dtype})")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=self.dtype,
            device_map=self.device,
        )
        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval()
        print(f"Model loaded: {self.model.config.hidden_size}d, "
              f"{self.model.config.num_hidden_layers}L")

    def _run_loglikelihood_task(
        self,
        task: BenchmarkTask,
        model,
        n: Optional[int] = None,
    ) -> Dict:
        """Run a log-likelihood scoring task.

        Scores every continuation of each example under *model*, then folds
        the (scores, gold_idx) pairs into an accuracy summary.

        Args:
            task: Task providing `load_examples(n=...)` with "context",
                "continuations" and "gold_idx" keys per example.
            model: The (base or Cortex-enhanced) model to score with.
            n: Optional cap on the number of examples.

        Returns:
            Accuracy dict from `accuracy_from_loglikelihoods` (at least
            "accuracy", "correct", "total").
        """
        print(f"  Loading examples for {task.name}...")
        examples = task.load_examples(n=n)
        print(f"  Scoring {len(examples)} examples...")
        scores_and_golds = []
        for i, ex in enumerate(examples):
            # Lightweight progress indicator every 10 examples.
            if (i + 1) % 10 == 0:
                print(f"    [{i+1}/{len(examples)}]")
            scores = log_likelihood_score(
                model, self.tokenizer,
                ex["context"], ex["continuations"],
                device=self.device,
            )
            scores_and_golds.append((scores, ex["gold_idx"]))
        return accuracy_from_loglikelihoods(scores_and_golds)

    def _run_memory_tasks(
        self,
        model,
        n_passkey: int = 5,
        passkey_lengths: Optional[List[int]] = None,
        n_multihop: Optional[int] = None,
    ) -> Dict:
        """Run memory-specific benchmarks.

        Args:
            model: Model under evaluation.
            n_passkey: Examples per context length for passkey retrieval.
            passkey_lengths: Context lengths to test; defaults to
                [128, 256, 512] when None.
            n_multihop: Optional cap on multi-hop examples (None = task default).

        Returns:
            Dict with "passkey_retrieval" and "multi_hop_memory" results.
        """
        results = {}
        # Passkey retrieval
        print("  Running passkey retrieval...")
        passkey = PasskeyRetrieval(context_lengths=passkey_lengths or [128, 256, 512])
        results["passkey_retrieval"] = passkey.run(
            model, self.tokenizer,
            n_per_length=n_passkey, device=self.device,
        )
        # Multi-hop memory
        print("  Running multi-hop memory...")
        multihop = MultiHopMemory()
        results["multi_hop_memory"] = multihop.run(
            model, self.tokenizer,
            n=n_multihop, device=self.device,
        )
        return results

    def inject_cortex(self) -> Dict:
        """
        Inject all Cortex modules into the model.

        Returns:
            Dict with module info: per-module parameter report, total Cortex
            parameter count, total model parameter count, and the Cortex
            overhead as a percentage of the full model.
        """
        # Imported lazily so the base-model phase never touches cortex.
        from cortex import (
            CortexSurgeon, MemoryBank, HallucinationGate,
            PauseAndThink, BacktrackHead, SteeringVector, AdaptiveDepth,
        )
        surgeon = CortexSurgeon(self.model)
        hidden_dim = surgeon.hidden_dim
        num_layers = surgeon.num_layers
        # Cross-attention requires num_heads to evenly divide hidden_dim;
        # pick the largest head count <= 8 that does.
        num_heads = 8
        while hidden_dim % num_heads != 0 and num_heads > 1:
            num_heads -= 1
        # Middle third of layers hosts memory/think/steering; deep third
        # hosts the hallucination gate.
        middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
        deep_layers = list(range(2 * num_layers // 3, num_layers))
        surgeon.add_module("memory", MemoryBank(
            hidden_dim=hidden_dim, num_slots=32, num_heads=num_heads,
            target_layers=middle_layers,
        ))
        surgeon.add_module("halluc_gate", HallucinationGate(
            hidden_dim=hidden_dim, bottleneck_dim=32,
            target_layers=deep_layers,
        ))
        surgeon.add_module("pause_think", PauseAndThink(
            hidden_dim=hidden_dim, num_think_tokens=4,
            target_layers=middle_layers,
        ))
        surgeon.add_module("backtrack", BacktrackHead(
            hidden_dim=hidden_dim, confidence_bottleneck=32,
            num_layers=num_layers, target_layers="all",
        ))
        surgeon.add_module("steering", SteeringVector(
            hidden_dim=hidden_dim, num_directions=2,
            direction_names=["truthfulness", "helpfulness"],
            target_layers=middle_layers,
        ))
        surgeon.add_module("adaptive_depth", AdaptiveDepth(
            hidden_dim=hidden_dim, target_layers="all",
        ))
        # Freeze the base model so only Cortex parameters are trainable.
        surgeon.operate(freeze_base=True)
        if self.cortex_weights:
            surgeon.load_cortex_modules(self.cortex_weights)
            print(f"  Loaded Cortex weights: {self.cortex_weights}")
        report = surgeon.get_parameter_report()
        total_cortex = sum(info["trainable"] for info in report.values())
        total_model = sum(p.numel() for p in self.model.parameters())
        # Keep the surgeon around so remove_cortex can undo the injection.
        self._surgeon = surgeon
        return {
            "total_cortex_params": total_cortex,
            "total_model_params": total_model,
            "overhead_pct": total_cortex / total_model * 100,
            "per_module": report,
        }

    def remove_cortex(self):
        """Remove Cortex modules and restore base model. No-op if none injected."""
        if hasattr(self, "_surgeon"):
            self._surgeon.remove_all()
            del self._surgeon

    @staticmethod
    def _print_banner(title: str) -> None:
        """Print a section divider around *title*."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    def _evaluate_phase(
        self,
        label: str,
        tasks: List[str],
        n: int,
        include_memory: bool,
        n_passkey: int,
        passkey_lengths: Optional[List[int]],
    ) -> Dict:
        """Run all tasks (and optionally the memory suite) for one phase.

        Shared by the BASE and CORTEX phases of `run_comparison` — the
        model evaluated is whatever `self.model` currently is (with or
        without Cortex modules injected).

        Args:
            label: Phase tag used in progress output ("BASE" or "CORTEX").
            tasks: Task names from TASK_REGISTRY.
            n: Number of examples per task.
            include_memory: Whether to run memory benchmarks.
            n_passkey: Passkey examples per context length.
            passkey_lengths: Context lengths for the passkey test.

        Returns:
            Dict mapping task name -> result dict (plus "memory" when
            include_memory is True), each with "time_seconds" added.
        """
        phase: Dict[str, Any] = {}
        for task_name in tasks:
            print(f"\n[{label}] Running {task_name}...")
            t0 = time.time()
            task_cls = TASK_REGISTRY[task_name]
            # The registry may hold a task class or an already-built instance.
            task = task_cls() if callable(task_cls) else task_cls
            result = self._run_loglikelihood_task(task, self.model, n=n)
            result["time_seconds"] = time.time() - t0
            phase[task_name] = result
            print(f"  {task_name}: {result['accuracy']:.4f} "
                  f"({result['correct']}/{result['total']}) "
                  f"[{result['time_seconds']:.1f}s]")
        if include_memory:
            print(f"\n[{label}] Running memory benchmarks...")
            t0 = time.time()
            mem_results = self._run_memory_tasks(
                self.model, n_passkey=n_passkey,
                passkey_lengths=passkey_lengths,
            )
            mem_results["time_seconds"] = time.time() - t0
            phase["memory"] = mem_results
            pk = mem_results["passkey_retrieval"]["overall"]
            mh = mem_results["multi_hop_memory"]
            print(f"  passkey: {pk['accuracy']:.4f} ({pk['correct']}/{pk['total']})")
            print(f"  multi_hop: {mh['accuracy']:.4f} ({mh['correct']}/{mh['total']})")
        return phase

    def run_comparison(
        self,
        tasks: Optional[List[str]] = None,
        n: int = 50,
        include_memory: bool = True,
        n_passkey: int = 5,
        passkey_lengths: Optional[List[int]] = None,
    ) -> Dict:
        """
        Run full comparison: base model vs Cortex-enhanced.

        Args:
            tasks: List of task names from TASK_REGISTRY. None = all.
            n: Number of examples per task.
            include_memory: Whether to run memory benchmarks.
            n_passkey: Number of passkey examples per context length.
            passkey_lengths: Context lengths for passkey test.

        Returns:
            Dict with base_results, cortex_results, and comparison.
        """
        if tasks is None:
            tasks = ["hellaswag", "piqa", "arc-easy", "winogrande"]
        results = {
            "model": self.model_name,
            "device": self.device,
            "dtype": str(self.dtype),
            "n_per_task": n,
            "tasks": tasks,
            "base": {},
            "cortex": {},
            "comparison": {},
        }

        # ===== BASE MODEL =====
        self._print_banner("PHASE 1: BASE MODEL EVALUATION")
        results["base"] = self._evaluate_phase(
            "BASE", tasks, n, include_memory, n_passkey, passkey_lengths,
        )

        # ===== CORTEX-ENHANCED MODEL =====
        self._print_banner("PHASE 2: CORTEX-ENHANCED MODEL EVALUATION")
        print("\nInjecting Cortex modules...")
        module_info = self.inject_cortex()
        print(f"  Cortex params: {module_info['total_cortex_params']:,} "
              f"({module_info['overhead_pct']:.2f}% overhead)")
        results["cortex_info"] = module_info
        results["cortex"] = self._evaluate_phase(
            "CORTEX", tasks, n, include_memory, n_passkey, passkey_lengths,
        )

        # ===== COMPARISON =====
        self._print_banner("COMPARISON: BASE vs CORTEX")
        for task_name in tasks:
            base_acc = results["base"][task_name]["accuracy"]
            cortex_acc = results["cortex"][task_name]["accuracy"]
            delta = cortex_acc - base_acc
            symbol = "↑" if delta > 0 else "↓" if delta < 0 else "="
            results["comparison"][task_name] = {
                "base": base_acc,
                "cortex": cortex_acc,
                "delta": delta,
            }
            print(f"  {task_name:20s} base={base_acc:.4f} cortex={cortex_acc:.4f} "
                  f"Δ={delta:+.4f} {symbol}")
        if include_memory:
            base_pk = results["base"]["memory"]["passkey_retrieval"]["overall"]["accuracy"]
            cortex_pk = results["cortex"]["memory"]["passkey_retrieval"]["overall"]["accuracy"]
            base_mh = results["base"]["memory"]["multi_hop_memory"]["accuracy"]
            cortex_mh = results["cortex"]["memory"]["multi_hop_memory"]["accuracy"]
            results["comparison"]["passkey"] = {
                "base": base_pk, "cortex": cortex_pk, "delta": cortex_pk - base_pk,
            }
            results["comparison"]["multi_hop"] = {
                "base": base_mh, "cortex": cortex_mh, "delta": cortex_mh - base_mh,
            }
            print(f"  {'passkey':20s} base={base_pk:.4f} cortex={cortex_pk:.4f} "
                  f"Δ={cortex_pk - base_pk:+.4f}")
            print(f"  {'multi_hop':20s} base={base_mh:.4f} cortex={cortex_mh:.4f} "
                  f"Δ={cortex_mh - base_mh:+.4f}")

        # Remove cortex modules to restore base model
        self.remove_cortex()
        return results
def print_summary(results: Dict):
    """Print a formatted summary of benchmark results.

    Args:
        results: Dict produced by `BenchmarkRunner.run_comparison`; reads
            "model", "n_per_task", "device" and "comparison" keys, and
            optionally "cortex_info" for the parameter-overhead footer.
    """
    print("\n" + "=" * 70)
    print(f"BENCHMARK SUMMARY: {results['model']}")
    print(f"n={results['n_per_task']} per task, device={results['device']}")
    print("=" * 70)
    print(f"\n{'Task':22s} {'Base':>8s} {'Cortex':>8s} {'Delta':>8s}")
    print("-" * 50)
    for task_name, comp in results["comparison"].items():
        delta_str = f"{comp['delta']:+.4f}"
        # Only mark deltas beyond a +-0.001 noise threshold with an arrow.
        symbol = " ↑" if comp["delta"] > 0.001 else " ↓" if comp["delta"] < -0.001 else "  "
        print(f"{task_name:22s} {comp['base']:8.4f} {comp['cortex']:8.4f} {delta_str:>8s}{symbol}")
    if "cortex_info" in results:
        info = results["cortex_info"]
        print(f"\nCortex overhead: {info['total_cortex_params']:,} params "
              f"({info['overhead_pct']:.2f}%)")
    print("=" * 70)