# Cortex: Modular Cognitive Plug-ins for Pretrained LLMs
**Surgically insert new cognitive capabilities into any pretrained transformer LLM – without retraining the base model.**
Cortex is a framework for performing *layer surgery* on pretrained language models. It injects lightweight, composable modules into the transformer's residual stream via PyTorch hooks, adding capabilities that address fundamental LLM failure modes. The base model weights are frozen; only the Cortex modules (~3% parameter overhead) are trainable.
## Architecture Overview
```
┌──────────────────────────────────────────────────────────────┐
│                   Pretrained LLM (Frozen)                    │
│                                                              │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐       │
│  │ Layer 0 │──▶│ Layer 1 │──▶│   ...   │──▶│ Layer N │──▶    │
│  └────┬────┘   └────┬────┘   └────┬────┘   └────┬────┘       │
│       │             │             │             │            │
│  ┌────▼────┐   ┌────▼────┐   ┌────▼────┐   ┌────▼────┐       │
│  │Adaptive │   │Backtrack│   │ Memory  │   │ Halluc  │       │
│  │ Depth   │   │  Head   │   │  Bank   │   │  Gate   │       │
│  │ (gate)  │   │(correct)│   │ (read/  │   │(suppress│       │
│  │         │   │         │   │  write) │   │ unsure) │       │
│  └─────────┘   └─────────┘   └─────────┘   └─────────┘       │
│       │             │             │                          │
│  ┌────▼────┐   ┌────▼────┐   ┌────▼────┐                     │
│  │Steering │   │ Pause   │   │         │                     │
│  │ Vector  │   │& Think  │   │         │                     │
│  │ (steer) │   │(compute)│   │         │                     │
│  └─────────┘   └─────────┘   └─────────┘                     │
└──────────────────────────────────────────────────────────────┘
```
## The 6 Modules
### 1. 🧠 MemoryBank – Persistent Episodic Memory
**Failure mode:** Limited context window, lost long-range dependencies, no working memory.
Injects a learnable memory matrix **M ∈ ℝ^{N×D}** into middle transformer layers. Hidden states read from memory via multi-head cross-attention and write back via LSTM-style gated updates. Memory persists across forward passes, enabling multi-turn memory and long-document reasoning.
- **Injection:** `POST_ATTENTION` (between attention and FFN)
- **Mechanism:** Cross-attention read → Output gate → LSTM-style write (sketched below)
- **Based on:** [LM2: Large Memory Models (Kang et al. 2025)](https://arxiv.org/abs/2502.06049), [WISE (Xia et al. 2024)](https://arxiv.org/abs/2405.14768)
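The shipped implementation lives in the `cortex` package and is not reproduced here; the following is a minimal sketch of the read → gate → write mechanism, where the class, the pooling used for writes, and the gate initialization are all illustrative assumptions:
```python
import torch
import torch.nn as nn

class MemoryBankSketch(nn.Module):
    """Illustrative memory read/write; not the shipped MemoryBank."""

    def __init__(self, hidden_dim: int, num_slots: int = 64, num_heads: int = 4):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(num_slots, hidden_dim) * 0.02)
        self.read_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.out_gate = nn.Linear(hidden_dim, hidden_dim)
        self.write_gate = nn.Linear(2 * hidden_dim, hidden_dim)
        nn.init.zeros_(self.out_gate.weight)         # zero-init output path...
        nn.init.constant_(self.out_gate.bias, -4.0)  # ...gate starts nearly closed

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Read: tokens cross-attend into the memory slots.
        mem = self.memory.unsqueeze(0).expand(h.size(0), -1, -1)
        read, _ = self.read_attn(h, mem, mem)
        h = h + torch.sigmoid(self.out_gate(h)) * read
        # Write: LSTM-style gated blend of pooled context into each slot
        # (persists across forward passes; the real update rule may differ).
        ctx = h.mean(dim=(0, 1)).expand_as(self.memory)
        g = torch.sigmoid(self.write_gate(torch.cat([self.memory, ctx], dim=-1)))
        self.memory.data = g * self.memory.data + (1 - g) * ctx.detach()
        return h
```
The near-closed output gate keeps a freshly injected module an approximate no-op, matching the zero-init principle described later.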
### 2. 🛡️ HallucinationGate – Confidence-Based Suppression
**Failure mode:** Hallucination – generating confident but factually wrong content.
A lightweight confidence probe reads the residual stream and outputs a per-token confidence score. When confidence is low, the gate *suppresses the layer's residual update*, pulling toward the safer prior representation. The model effectively learns to say "I don't know" at the representation level.
- **Injection:** `POST_FFN` (after full transformer block)
- **Mechanism:** Confidence probe → Soft gate → Suppress uncertain updates (sketched below)
- **Key insight:** Internal states contain more information about correctness than output distributions: I(Θ; K(X) | X) ≥ I(Y; K(X) | X) + Δ, with internal representations Θ, output distribution Y, correctness signal K(X), and information gap Δ ≥ 0
- **Based on:** [InternalInspector (Chen et al. 2024)](https://arxiv.org/abs/2406.12053), [The Map of Misbelief (2025)](https://arxiv.org/abs/2511.10837)
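A minimal sketch of the probe-and-suppress idea (the two-tensor interface and the probe architecture are assumptions, not the shipped design):
```python
import torch
import torch.nn as nn

class HallucinationGateSketch(nn.Module):
    """Illustrative confidence probe + soft suppression; not the shipped module."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.probe = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.GELU(),
            nn.Linear(hidden_dim // 4, 1),
        )

    def forward(self, h_prev: torch.Tensor, h_new: torch.Tensor) -> torch.Tensor:
        conf = torch.sigmoid(self.probe(h_new))  # per-token confidence in [0, 1]
        update = h_new - h_prev                  # this block's residual update
        return h_prev + conf * update            # low confidence -> keep safer prior
```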
### 3. 💭 PauseAndThink – Latent Computation Tokens
**Failure mode:** Fixed compute per token, shallow reasoning, limited "thinking time."
Injects K learnable "thinking" token embeddings that attend to all real tokens, perform computation, then compress their information back into the original sequence positions via gated cross-attention. Like chain-of-thought, but entirely in latent space – no extra output tokens needed.
- **Injection:** `RESIDUAL_STREAM` (wraps full block)
- **Mechanism:** Context-conditioned thinking tokens → Attention to real tokens → Gated compression back (sketched below)
- **Based on:** [Pause Tokens (Goyal et al. 2023)](https://arxiv.org/abs/2310.02226), [Thoughtbubbles (2025)](https://arxiv.org/abs/2510.00219)
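A compact sketch of the latent thinking-token mechanism (the attention wiring and gate initialization are illustrative assumptions):
```python
import torch
import torch.nn as nn

class PauseAndThinkSketch(nn.Module):
    """Illustrative latent thinking tokens; not the shipped module."""

    def __init__(self, hidden_dim: int, num_think_tokens: int = 8, num_heads: int = 4):
        super().__init__()
        self.think = nn.Parameter(torch.randn(num_think_tokens, hidden_dim) * 0.02)
        self.think_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.merge_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.gate = nn.Linear(hidden_dim, hidden_dim)
        nn.init.zeros_(self.gate.weight)
        nn.init.constant_(self.gate.bias, -4.0)  # near-closed gate: ~no-op at injection

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        think = self.think.unsqueeze(0).expand(h.size(0), -1, -1)
        think, _ = self.think_attn(think, h, h)       # thinking tokens read the sequence
        merged, _ = self.merge_attn(h, think, think)  # compress results back to tokens
        return h + torch.sigmoid(self.gate(h)) * merged
```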
### 4. ↩️ BacktrackHead – Learned Self-Correction
**Failure mode:** Commitment to bad intermediate representations, no backtracking.
Monitors confidence *across layers*. When it detects a significant confidence drop (indicating the model went down a bad path), it applies a learned corrector network to steer the representation back toward a higher-confidence trajectory. Effectively implements architectural self-correction.
- **Injection:** `RESIDUAL_STREAM` (all layers)
- **Mechanism:** Per-layer confidence probe → Drop detection → Bottleneck corrector network (sketched below)
- **Based on:** [GateSkip (2025)](https://arxiv.org/abs/2510.13876), [River-LLM (2025)](https://arxiv.org/abs/2604.18396), [Self-Correction (Welleck et al. 2022)](https://arxiv.org/abs/2211.00053)
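A rough sketch of drop detection plus a bottleneck corrector (the threshold, bottleneck width, and state handling are illustrative assumptions):
```python
import torch
import torch.nn as nn

class BacktrackHeadSketch(nn.Module):
    """Illustrative confidence-drop detection + correction; not the shipped module."""

    def __init__(self, hidden_dim: int, drop_threshold: float = 0.15):
        super().__init__()
        self.probe = nn.Linear(hidden_dim, 1)
        self.corrector = nn.Sequential(            # bottleneck corrector network
            nn.Linear(hidden_dim, hidden_dim // 8),
            nn.GELU(),
            nn.Linear(hidden_dim // 8, hidden_dim),
        )
        nn.init.zeros_(self.corrector[-1].weight)  # zero-init: no-op at injection
        nn.init.zeros_(self.corrector[-1].bias)
        self.drop_threshold = drop_threshold
        self.prev_conf = None                      # carries across layer calls

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        conf = torch.sigmoid(self.probe(h))        # per-layer, per-token confidence
        if self.prev_conf is not None:
            drop = (self.prev_conf - conf).clamp(min=0.0)
            trigger = (drop > self.drop_threshold).float()
            h = h + trigger * self.corrector(h)    # correct only where confidence fell
        self.prev_conf = conf.detach()             # reset this between sequences
        return h
```
Because one instance is shared across its target layers (see Design Principles), per-layer state like `prev_conf` must be reset between examples.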
### 5. 🧭 SteeringVector – Behavioral Control
**Failure mode:** Behavioral inflexibility, inability to control style/truthfulness/safety at runtime.
Maintains named "concept directions" in activation space. Directions can be extracted via contrastive activation analysis (RepE) or learned end-to-end. Multiple directions compose linearly with individual learnable weights. Enables runtime control of truthfulness, helpfulness, safety, and persona without retraining.
- **Injection:** `RESIDUAL_STREAM` (middle layers)
- **Mechanism:** h_new = h + layer_scale × Σ(α_i × direction_i) (sketched below)
- **Based on:** [Representation Engineering (Zou et al. 2023)](https://arxiv.org/abs/2310.01405)
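The mechanism line maps directly onto a few lines of PyTorch; a sketch with assumed parameter shapes:
```python
import torch
import torch.nn as nn

class SteeringVectorSketch(nn.Module):
    """Illustrative linear composition of concept directions; not the shipped module."""

    def __init__(self, hidden_dim: int, num_directions: int = 4):
        super().__init__()
        self.directions = nn.Parameter(torch.zeros(num_directions, hidden_dim))
        self.alphas = nn.Parameter(torch.zeros(num_directions))  # per-direction weights
        self.layer_scale = nn.Parameter(torch.ones(1))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h_new = h + layer_scale * sum_i alpha_i * direction_i
        steer = (self.alphas.unsqueeze(-1) * self.directions).sum(dim=0)
        return h + self.layer_scale * steer
```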
### 6. ⚡ AdaptiveDepth – Dynamic Layer Skipping
**Failure mode:** Fixed compute depth, wasted computation, overthinking (representation collapse).
Each layer gets a learned gate that decides per-token whether to execute or skip. Easy tokens ("the") skip many layers; hard tokens (complex reasoning) use all of them. Includes budget regularization to target a desired compute fraction.
- **Injection:** `POST_FFN` (all layers)
- **Mechanism:** Gate network → Sigmoid → Scale residual contribution → Budget loss (sketched below)
- **Based on:** [Mixture of Depths (Raposo et al. 2024)](https://arxiv.org/abs/2404.02258), [GateSkip (2025)](https://arxiv.org/abs/2510.13876), [Router-Tuning (2024)](https://arxiv.org/abs/2410.13184)
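A sketch of the gate and its budget regularizer (the interface and target fraction are assumptions; the real module exposes `get_budget_loss()`, used in the training example below):
```python
import torch
import torch.nn as nn

class AdaptiveDepthSketch(nn.Module):
    """Illustrative per-token execute/skip gate; not the shipped module."""

    def __init__(self, hidden_dim: int, target_budget: float = 0.7):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, 1)
        nn.init.zeros_(self.gate.weight)
        nn.init.constant_(self.gate.bias, 4.0)  # gates start open: full depth at injection
        self.target_budget = target_budget
        self.last_gates = None

    def forward(self, h_prev: torch.Tensor, h_new: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate(h_prev))    # per-token execute probability
        self.last_gates = g
        return h_prev + g * (h_new - h_prev)    # g -> 0 skips this block's contribution

    def get_budget_loss(self) -> torch.Tensor:
        # Penalize deviation of mean gate activity from the target compute fraction.
        return (self.last_gates.mean() - self.target_budget) ** 2
```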
## Quick Start
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from cortex import (
    CortexSurgeon, MemoryBank, HallucinationGate,
    PauseAndThink, BacktrackHead, SteeringVector, AdaptiveDepth
)

# Load any pretrained LLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Create surgeon
surgeon = CortexSurgeon(model)
hidden_dim = surgeon.hidden_dim

# Add modules – each targets specific layers
surgeon.add_module("memory", MemoryBank(hidden_dim=hidden_dim, num_slots=64))
surgeon.add_module("halluc_gate", HallucinationGate(hidden_dim=hidden_dim))
surgeon.add_module("pause_think", PauseAndThink(hidden_dim=hidden_dim, num_think_tokens=8))
surgeon.add_module("backtrack", BacktrackHead(hidden_dim=hidden_dim, num_layers=surgeon.num_layers))
surgeon.add_module("steering", SteeringVector(hidden_dim=hidden_dim, num_directions=4))
surgeon.add_module("adaptive_depth", AdaptiveDepth(hidden_dim=hidden_dim))

# Perform surgery (freezes base model, only Cortex modules train)
surgeon.operate(freeze_base=True)

# Use the model normally – Cortex modules are active
inputs = tokenizer("The meaning of life is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))

# Toggle modules on/off at runtime
surgeon.modules["halluc_gate"].disable()

# Save only Cortex weights (~3% of model size)
surgeon.save_cortex_modules("cortex_weights.pt")
```
## Benchmark Harness
Cortex includes a comprehensive benchmark harness for comparing base LLMs against Cortex-enhanced versions. It evaluates across standard NLP benchmarks and Cortex-specific capability tests.
### Standard Benchmarks
| Task              | Type                    | Choices | Dataset               | Few-Shot |
|-------------------|-------------------------|---------|-----------------------|----------|
| **HellaSwag**     | Commonsense NLI         | 4       | `Rowan/hellaswag`     | 5-shot   |
| **ARC-Easy**      | Science QA              | 3-5     | `allenai/ai2_arc`     | 5-shot   |
| **ARC-Challenge** | Science QA (hard)       | 3-5     | `allenai/ai2_arc`     | 5-shot   |
| **PIQA**          | Physical intuition      | 2       | `gimmaru/piqa`        | 0-shot   |
| **WinoGrande**    | Coreference             | 2       | `allenai/winogrande`  | 5-shot   |
| **MMLU**          | Multi-domain knowledge  | 4       | `cais/mmlu`           | 5-shot   |
| **HaluEval**      | Hallucination detection | 2       | `pminervini/HaluEval` | 0-shot   |
### Cortex-Specific Benchmarks
| Task                  | Tests                                     | Method                                                          |
|-----------------------|-------------------------------------------|-----------------------------------------------------------------|
| **Passkey Retrieval** | Long-context memory, attention to details | Generation + substring match at 128/256/512/1024 token contexts |
| **Multi-Hop Memory**  | Compositional reasoning, fact chaining    | Generation + answer extraction from 3-hop fact chains           |
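For intuition, a passkey example can be built along these lines (an illustrative prompt template and scoring check, not necessarily the harness's exact format):
```python
import random

def make_passkey_example(filler_repeats: int = 40) -> tuple[str, str]:
    """Hide a passkey inside filler text; the model must retrieve it verbatim."""
    passkey = str(random.randint(10000, 99999))
    filler = "The grass is green. The sky is blue. " * filler_repeats
    prompt = (
        f"{filler}\nThe passkey is {passkey}. Remember it.\n{filler}\n"
        "What is the passkey? The passkey is"
    )
    return prompt, passkey

prompt, answer = make_passkey_example()
# Scoring: greedy decode, then check `answer in generated_text` (substring match).
```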
### Running Benchmarks
```bash
# Quick test (10 examples per task)
python -m benchmark.run_benchmark --n 10 --tasks hellaswag piqa

# Standard suite (50 examples, default tasks)
python -m benchmark.run_benchmark --n 50

# Full evaluation with all tasks
python -m benchmark.run_benchmark --n 0 --tasks hellaswag piqa arc-easy arc-challenge winogrande mmlu

# Custom model
python -m benchmark.run_benchmark --model meta-llama/Llama-3.2-1B --n 50

# Save JSON results
python -m benchmark.run_benchmark --n 50 --output results.json

# Skip memory benchmarks
python -m benchmark.run_benchmark --n 50 --no-memory

# Custom passkey test
python -m benchmark.run_benchmark --n 20 --passkey-lengths 128 256 512 1024 --n-passkey 10

# Small full-suite run on Llama-3.2-1B (cf. example output below)
python -m benchmark.run_benchmark --n 10 --model meta-llama/Llama-3.2-1B --tasks hellaswag piqa arc-easy arc-challenge winogrande mmlu
```
### Scoring Method
- **Multiple-choice tasks:** Log-likelihood scoring – computes the average log-probability the model assigns to each continuation and picks the highest. This is the standard approach used by lm-evaluation-harness and the Open LLM Leaderboard.
- **Generation tasks:** Greedy decode + substring match against the expected answer.
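The multiple-choice scorer reduces to a few lines; a sketch (the helper name is an assumption, and boundary retokenization effects are ignored for brevity):
```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def choice_logprob(model, tokenizer, context: str, continuation: str) -> float:
    """Average log-probability of `continuation` given `context`."""
    ctx_len = tokenizer(context, return_tensors="pt").input_ids.shape[1]
    full_ids = tokenizer(context + continuation, return_tensors="pt").input_ids
    logits = model(full_ids).logits
    # Logits at position i predict token i+1, so shift by one.
    logprobs = F.log_softmax(logits[0, ctx_len - 1 : -1], dim=-1)
    cont_ids = full_ids[0, ctx_len:]
    return logprobs.gather(1, cont_ids.unsqueeze(-1)).mean().item()

# best_choice = max(choices, key=lambda c: choice_logprob(model, tokenizer, prompt, c))
```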
### Example Output (Llama-3.2-1B, n=10)
```
======================================================================
BENCHMARK SUMMARY: meta-llama/Llama-3.2-1B
n=10 per task, device=mps
======================================================================
Task                    Base      Cortex      Delta
--------------------------------------------------
hellaswag             0.6000      0.6000    +0.0000
piqa                  0.2000      0.2000    +0.0000
arc-easy              0.4000      0.4000    +0.0000
arc-challenge         0.5000      0.5000    +0.0000
winogrande            0.6000      0.6000    +0.0000
mmlu                  0.4000      0.4000    +0.0000
passkey               1.0000      1.0000    +0.0000
multi_hop             1.0000      1.0000    +0.0000
Cortex overhead: 53,708,968 params (4.35%)
======================================================================
```
> **Note:** Cortex modules are untrained at injection and initialize as exact no-ops, so a freshly injected model should match the base model, as above. Positive deltas require Cortex-specific training or calibrated steering directions.
### Programmatic Usage
```python
from benchmark.runner import BenchmarkRunner

runner = BenchmarkRunner(model_name="HuggingFaceTB/SmolLM2-135M")
results = runner.run_comparison(
    tasks=["hellaswag", "piqa", "arc-easy"],
    n=50,
    include_memory=True,
    passkey_lengths=[128, 256, 512],
)
BenchmarkRunner.print_summary(results)
```
## Design Principles
### 1. Zero-Init for Stable Injection
All modules initialize their output projections to zero (or near-zero via negative gate biases). This means at injection time the model behaves identically to the original – Cortex modules are "invisible" and gradually learn to contribute during training.
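The pattern, shown on a hypothetical residual adapter (not a Cortex module):
```python
import torch.nn as nn

class ZeroInitAdapter(nn.Module):
    """Residual adapter whose output projection starts at zero: an exact no-op."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.down = nn.Linear(hidden_dim, hidden_dim // 8)
        self.up = nn.Linear(hidden_dim // 8, hidden_dim)
        nn.init.zeros_(self.up.weight)  # zeroed output projection...
        nn.init.zeros_(self.up.bias)    # ...so h + adapter(h) == h at injection

    def forward(self, h):
        return h + self.up(self.down(h).relu())
```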
### 2. Hook-Based Surgery
Modules are injected via PyTorch `register_forward_hook` / `register_forward_pre_hook`. No model code is modified. This works with any HuggingFace `transformers` model that has a standard layer structure.
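A minimal sketch of the idea (the hook logic and layer path are illustrative; `CortexSurgeon` handles the real bookkeeping):
```python
def attach(cortex_module, layer):
    """Run a module on a decoder block's hidden states via a forward hook."""
    def hook(block, inputs, output):
        # HF decoder blocks usually return a tuple; element 0 is the hidden states.
        hidden = output[0] if isinstance(output, tuple) else output
        new_hidden = cortex_module(hidden)
        return (new_hidden,) + output[1:] if isinstance(output, tuple) else new_hidden
    return layer.register_forward_hook(hook)

# handle = attach(my_module, model.model.layers[10])  # e.g., a LLaMA-style model
# handle.remove()                                     # detach cleanly at any time
```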
### 3. Shared Parameters Across Layers
Each module instance is shared across its target layers. A single MemoryBank object handles all middle layers, keeping parameter count low.
### 4. Base Model Freezing
By default, all base model parameters are frozen. Only Cortex module parameters are trainable. This means:
- No catastrophic forgetting of the base model's capabilities
- Tiny training cost (~3% of parameters)
- Multiple Cortex configurations can be saved/loaded/swapped
### 5. Composability
All modules are independent and composable. Use any combination:
- Memory + HallucinationGate for factual QA
- PauseAndThink + AdaptiveDepth for reasoning tasks
- SteeringVector alone for behavioral control
## Injection Points
| Point             | Location                    | Best For                                             |
|-------------------|-----------------------------|------------------------------------------------------|
| `PRE_ATTENTION`   | Before self-attention       | Input preprocessing, prefix injection                |
| `POST_ATTENTION`  | After attention, before FFN | Memory augmentation (reads enhance attention output) |
| `PRE_FFN`         | Before FFN                  | Gate what the FFN processes                          |
| `POST_FFN`        | After full block            | Gating, confidence estimation                        |
| `RESIDUAL_STREAM` | Wraps entire block          | Steering vectors, thinking tokens, backtracking      |
## Layer Targeting
```python
# Target specific layers
surgeon.add_module("mod", module, target_layers=[0, 1, 2, 3])

# Target all layers
surgeon.add_module("mod", module, target_layers="all")

# Target middle third (best for steering/memory)
surgeon.add_module("mod", module, target_layers="middle")

# Target deep layers (best for output-facing modifications)
surgeon.add_module("mod", module, target_layers="deep")
```
## Compatible Models
Tested and working with any model using the standard `model.layers[i]` structure:
- **LLaMA** family (LLaMA 2/3, CodeLLaMA)
- **Mistral** / Mixtral
- **Qwen2**
- **Gemma**
- **Phi**
- **SmolLM**
- **GPT-2** / GPT-Neo (uses `transformer.h`)
## Monitoring
```python
# Hallucination confidence
confidence = surgeon.modules["halluc_gate"].get_confidence()

# Backtracking status
was_triggered = surgeon.modules["backtrack"].was_triggered()
confidence_per_layer = surgeon.modules["backtrack"].get_confidence_history()

# Adaptive depth statistics
gate_stats = surgeon.modules["adaptive_depth"].get_gate_stats()
print(f"Mean gate: {gate_stats['mean']:.3f}, Skip fraction: {gate_stats['skip_frac']:.3f}")

# Steering vector info
for name, info in surgeon.modules["steering"].get_direction_info().items():
    print(f"{name}: alpha={info['alpha']:.3f}")

# Parameter report
report = surgeon.get_parameter_report()
```
## Extracting Steering Directions (RepE)
```python
# Contrastive activation pairs for "truthfulness"
positive = [
    "I know for certain that the Earth orbits the Sun.",
    "Scientific evidence clearly shows vaccines are safe.",
    "I don't know the answer to that question.",
]
negative = [
    "The Earth is flat and NASA is lying.",
    "Vaccines cause autism according to my research.",
    "I'm absolutely certain about this made-up fact.",
]

# Extract direction from layer 15
direction = SteeringVector.extract_direction(
    model, positive, negative, tokenizer,
    layer_idx=15, device="cuda"
)

# Set it in the steering module
surgeon.modules["steering"].set_direction("truthfulness", direction, alpha=10.0)
```
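Under the hood, contrastive extraction is typically a mean difference of activations; a sketch of that idea (last-token pooling and unit normalization are assumptions and may differ from what `extract_direction` actually does):
```python
import torch

@torch.no_grad()
def mean_diff_direction(model, tokenizer, positive, negative, layer_idx, device="cuda"):
    """Mean hidden-state difference between positive and negative prompts."""
    def last_token_states(texts):
        states = []
        for text in texts:
            ids = tokenizer(text, return_tensors="pt").to(device)
            out = model(**ids, output_hidden_states=True)
            states.append(out.hidden_states[layer_idx][0, -1])  # chosen layer, last token
        return torch.stack(states)

    direction = last_token_states(positive).mean(0) - last_token_states(negative).mean(0)
    return direction / direction.norm()  # unit direction; alpha sets the strength
```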
## Training
For benchmark-style supervised tuning, use the training CLI. It freezes the base model, injects Cortex modules, optimizes only the Cortex parameters, and saves the adapter weights:
```bash
python -m benchmark.train_cortex \
    --model meta-llama/Llama-3.2-1B \
    --tasks hellaswag piqa arc-easy winogrande \
    --n-train 32 \
    --epochs 1 \
    --output cortex_tuned.pt

python -m benchmark.run_benchmark \
    --model meta-llama/Llama-3.2-1B \
    --cortex-weights cortex_tuned.pt \
    --n 50
```
For custom training loops:
```python
import torch.optim as optim

# Only train Cortex parameters
optimizer = optim.AdamW(surgeon.get_trainable_parameters(), lr=1e-4)

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    # Add adaptive depth budget loss
    loss = loss + surgeon.modules["adaptive_depth"].get_budget_loss()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
## Test Results (SmolLM2-135M)
```
✓ All 9 tests passed
✓ 4,286,918 Cortex params (3.19% overhead on 134M model)
✓ Base model: 0 gradients (fully frozen)
✓ Cortex modules: gradients flowing
✓ Enable/disable: exact zero diff when disabled
✓ Generation: produces coherent output
✓ Save/load: 16.4 KB checkpoint
```
## Citation
If you use Cortex in your research, please cite the papers that inspired each module:
```bibtex
@article{kang2025lm2,
  title={LM2: Large Memory Models},
  author={Kang, et al.},
  journal={arXiv:2502.06049},
  year={2025}
}
@article{chen2024internalinspector,
  title={InternalInspector $I^2$: Robust Confidence Estimation in LLMs through Internal States},
  author={Chen, et al.},
  journal={arXiv:2406.12053},
  year={2024}
}
@article{goyal2023think,
  title={Think before you speak: Training Language Models With Pause Tokens},
  author={Goyal, et al.},
  journal={arXiv:2310.02226},
  year={2023}
}
@article{zou2023representation,
  title={Representation Engineering: A Top-Down Approach to AI Transparency},
  author={Zou, et al.},
  journal={arXiv:2310.01405},
  year={2023}
}
@article{raposo2024mixture,
  title={Mixture-of-Depths: Dynamically Allocating Compute in Transformer-Based Language Models},
  author={Raposo, et al.},
  journal={arXiv:2404.02258},
  year={2024}
}
```
## License
Apache 2.0