Commit ·
0de2901
1
Parent(s): 0ac64e3
Enhance benchmark and Cortex modules with new training utilities and improved state management. Update README with example output for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor scoring functions to reset Cortex state between examples and ensure consistent output. Modify task handling to ensure proper formatting of input data.
Browse files- README.md +34 -12
- benchmark/run_benchmark.py +5 -0
- benchmark/runner.py +6 -0
- benchmark/scoring.py +11 -0
- benchmark/tasks.py +9 -3
- benchmark/train_cortex.py +161 -0
- benchmark/tuning.py +97 -0
- cortex/adaptive_depth.py +21 -9
- cortex/backtrack_head.py +6 -1
- cortex/core.py +11 -3
- cortex/hallucination_gate.py +18 -5
- cortex/memory_bank.py +5 -1
- test_cortex.py +2 -2
README.md
CHANGED
|
@@ -181,28 +181,30 @@ python -m benchmark.run_benchmark --n 10 --model meta-llama/Llama-3.2-1B --tasks
|
|
| 181 |
- **Multiple-choice tasks:** Log-likelihood scoring — computes average log-probability the model assigns to each continuation, picks the highest. This is the standard approach used by lm-evaluation-harness and Open LLM Leaderboard.
|
| 182 |
- **Generation tasks:** Greedy decode + substring match against expected answer.
|
| 183 |
|
| 184 |
-
### Example Output (
|
| 185 |
|
| 186 |
```
|
| 187 |
======================================================================
|
| 188 |
-
BENCHMARK SUMMARY:
|
| 189 |
-
n=
|
| 190 |
======================================================================
|
| 191 |
|
| 192 |
Task Base Cortex Delta
|
| 193 |
--------------------------------------------------
|
| 194 |
-
hellaswag 0.
|
| 195 |
-
piqa 0.
|
| 196 |
-
arc-easy 0.
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
======================================================================
|
| 203 |
```
|
| 204 |
|
| 205 |
-
> **Note:** Cortex modules are untrained at injection
|
| 206 |
|
| 207 |
### Programmatic Usage
|
| 208 |
|
|
@@ -328,6 +330,26 @@ surgeon.modules["steering"].set_direction("truthfulness", direction, alpha=10.0)
|
|
| 328 |
|
| 329 |
## Training
|
| 330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
```python
|
| 332 |
import torch.optim as optim
|
| 333 |
|
|
|
|
| 181 |
- **Multiple-choice tasks:** Log-likelihood scoring — computes average log-probability the model assigns to each continuation, picks the highest. This is the standard approach used by lm-evaluation-harness and Open LLM Leaderboard.
|
| 182 |
- **Generation tasks:** Greedy decode + substring match against expected answer.
|
| 183 |
|
| 184 |
+
### Example Output (Llama-3.2-1B, n=10)
|
| 185 |
|
| 186 |
```
|
| 187 |
======================================================================
|
| 188 |
+
BENCHMARK SUMMARY: meta-llama/Llama-3.2-1B
|
| 189 |
+
n=10 per task, device=mps
|
| 190 |
======================================================================
|
| 191 |
|
| 192 |
Task Base Cortex Delta
|
| 193 |
--------------------------------------------------
|
| 194 |
+
hellaswag 0.6000 0.6000 +0.0000
|
| 195 |
+
piqa 0.2000 0.2000 +0.0000
|
| 196 |
+
arc-easy 0.4000 0.4000 +0.0000
|
| 197 |
+
arc-challenge 0.5000 0.5000 +0.0000
|
| 198 |
+
winogrande 0.6000 0.6000 +0.0000
|
| 199 |
+
mmlu 0.4000 0.4000 +0.0000
|
| 200 |
+
passkey 1.0000 1.0000 +0.0000
|
| 201 |
+
multi_hop 1.0000 1.0000 +0.0000
|
| 202 |
+
|
| 203 |
+
Cortex overhead: 53,708,968 params (4.35%)
|
| 204 |
======================================================================
|
| 205 |
```
|
| 206 |
|
| 207 |
+
> **Note:** Cortex modules are untrained at injection and initialize as exact no-ops for model behavior. Freshly injected modules should match the base model; positive deltas require Cortex-specific training or calibrated steering directions.
|
| 208 |
|
| 209 |
### Programmatic Usage
|
| 210 |
|
|
|
|
| 330 |
|
| 331 |
## Training
|
| 332 |
|
| 333 |
+
For benchmark-style supervised tuning, use the training CLI. It freezes the base
|
| 334 |
+
model, injects Cortex modules, optimizes only Cortex parameters, and saves the
|
| 335 |
+
adapter weights:
|
| 336 |
+
|
| 337 |
+
```bash
|
| 338 |
+
python -m benchmark.train_cortex \
|
| 339 |
+
--model meta-llama/Llama-3.2-1B \
|
| 340 |
+
--tasks hellaswag piqa arc-easy winogrande \
|
| 341 |
+
--n-train 32 \
|
| 342 |
+
--epochs 1 \
|
| 343 |
+
--output cortex_tuned.pt
|
| 344 |
+
|
| 345 |
+
python -m benchmark.run_benchmark \
|
| 346 |
+
--model meta-llama/Llama-3.2-1B \
|
| 347 |
+
--cortex-weights cortex_tuned.pt \
|
| 348 |
+
--n 50
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
For custom training loops:
|
| 352 |
+
|
| 353 |
```python
|
| 354 |
import torch.optim as optim
|
| 355 |
|
benchmark/run_benchmark.py
CHANGED
|
@@ -68,6 +68,10 @@ def main():
|
|
| 68 |
"--output", type=str, default=None,
|
| 69 |
help="Path to save JSON results",
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
args = parser.parse_args()
|
| 73 |
|
|
@@ -77,6 +81,7 @@ def main():
|
|
| 77 |
model_name=args.model,
|
| 78 |
device=args.device,
|
| 79 |
dtype=args.dtype,
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
n = args.n if args.n > 0 else None
|
|
|
|
| 68 |
"--output", type=str, default=None,
|
| 69 |
help="Path to save JSON results",
|
| 70 |
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--cortex-weights", type=str, default=None,
|
| 73 |
+
help="Optional Cortex weights file to load before the Cortex phase",
|
| 74 |
+
)
|
| 75 |
|
| 76 |
args = parser.parse_args()
|
| 77 |
|
|
|
|
| 81 |
model_name=args.model,
|
| 82 |
device=args.device,
|
| 83 |
dtype=args.dtype,
|
| 84 |
+
cortex_weights=args.cortex_weights,
|
| 85 |
)
|
| 86 |
|
| 87 |
n = args.n if args.n > 0 else None
|
benchmark/runner.py
CHANGED
|
@@ -40,8 +40,10 @@ class BenchmarkRunner:
|
|
| 40 |
model_name: str = "HuggingFaceTB/SmolLM2-135M",
|
| 41 |
device: str = "auto",
|
| 42 |
dtype: str = "float32",
|
|
|
|
| 43 |
):
|
| 44 |
self.model_name = model_name
|
|
|
|
| 45 |
|
| 46 |
if device == "auto":
|
| 47 |
self.device = resolve_torch_device("auto")
|
|
@@ -167,6 +169,10 @@ class BenchmarkRunner:
|
|
| 167 |
))
|
| 168 |
|
| 169 |
surgeon.operate(freeze_base=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
report = surgeon.get_parameter_report()
|
| 172 |
total_cortex = sum(info["trainable"] for info in report.values())
|
|
|
|
| 40 |
model_name: str = "HuggingFaceTB/SmolLM2-135M",
|
| 41 |
device: str = "auto",
|
| 42 |
dtype: str = "float32",
|
| 43 |
+
cortex_weights: Optional[str] = None,
|
| 44 |
):
|
| 45 |
self.model_name = model_name
|
| 46 |
+
self.cortex_weights = cortex_weights
|
| 47 |
|
| 48 |
if device == "auto":
|
| 49 |
self.device = resolve_torch_device("auto")
|
|
|
|
| 169 |
))
|
| 170 |
|
| 171 |
surgeon.operate(freeze_base=True)
|
| 172 |
+
|
| 173 |
+
if self.cortex_weights:
|
| 174 |
+
surgeon.load_cortex_modules(self.cortex_weights)
|
| 175 |
+
print(f" Loaded Cortex weights: {self.cortex_weights}")
|
| 176 |
|
| 177 |
report = surgeon.get_parameter_report()
|
| 178 |
total_cortex = sum(info["trainable"] for info in report.values())
|
benchmark/scoring.py
CHANGED
|
@@ -17,6 +17,15 @@ import re
|
|
| 17 |
from cortex.torch_device import resolve_torch_device
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@torch.no_grad()
|
| 21 |
def log_likelihood_score(
|
| 22 |
model,
|
|
@@ -63,6 +72,7 @@ def log_likelihood_score(
|
|
| 63 |
|
| 64 |
# Forward pass
|
| 65 |
input_ids = torch.tensor([full_ids], device=device)
|
|
|
|
| 66 |
|
| 67 |
# Truncate if too long for model
|
| 68 |
max_len = getattr(model.config, "max_position_embeddings", 2048)
|
|
@@ -121,6 +131,7 @@ def generate_and_check(
|
|
| 121 |
if device is None:
|
| 122 |
device = resolve_torch_device("auto")
|
| 123 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
|
|
|
|
| 124 |
|
| 125 |
# Pad token
|
| 126 |
pad_token_id = tokenizer.pad_token_id
|
|
|
|
| 17 |
from cortex.torch_device import resolve_torch_device
|
| 18 |
|
| 19 |
|
| 20 |
+
def reset_cortex_state(model, batch_size: int = 1):
    """Clear per-example runtime state on every Cortex module attached to ``model``.

    The surgeon is looked up through the ``_cortex_surgeon`` attribute of the
    model. When that attribute is missing (or ``None``) the call does nothing.

    Args:
        model: Model that may carry a ``_cortex_surgeon`` attribute.
        batch_size: Forwarded to each module's ``reset_state``.
    """
    attached = getattr(model, "_cortex_surgeon", None)
    if attached is not None:
        for cortex_module in attached.modules.values():
            cortex_module.reset_state(batch_size=batch_size)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
@torch.no_grad()
|
| 30 |
def log_likelihood_score(
|
| 31 |
model,
|
|
|
|
| 72 |
|
| 73 |
# Forward pass
|
| 74 |
input_ids = torch.tensor([full_ids], device=device)
|
| 75 |
+
reset_cortex_state(model, batch_size=input_ids.shape[0])
|
| 76 |
|
| 77 |
# Truncate if too long for model
|
| 78 |
max_len = getattr(model.config, "max_position_embeddings", 2048)
|
|
|
|
| 131 |
if device is None:
|
| 132 |
device = resolve_torch_device("auto")
|
| 133 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
|
| 134 |
+
reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])
|
| 135 |
|
| 136 |
# Pad token
|
| 137 |
pad_token_id = tokenizer.pad_token_id
|
benchmark/tasks.py
CHANGED
|
@@ -75,7 +75,10 @@ class HellaSwag(BenchmarkTask):
|
|
| 75 |
train_examples = []
|
| 76 |
for row in few_shot_ds:
|
| 77 |
ctx = row["ctx"]
|
| 78 |
-
endings =
|
|
|
|
|
|
|
|
|
|
| 79 |
gold = int(row["label"])
|
| 80 |
train_examples.append({
|
| 81 |
"context": ctx,
|
|
@@ -85,7 +88,10 @@ class HellaSwag(BenchmarkTask):
|
|
| 85 |
|
| 86 |
for row in ds:
|
| 87 |
ctx = row["ctx"]
|
| 88 |
-
endings =
|
|
|
|
|
|
|
|
|
|
| 89 |
gold = int(row["label"])
|
| 90 |
examples.append({
|
| 91 |
"context": ctx,
|
|
@@ -132,7 +138,7 @@ class ARC(BenchmarkTask):
|
|
| 132 |
choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
|
| 133 |
context = f"Question: {question}\n{choice_str}\nAnswer:"
|
| 134 |
|
| 135 |
-
continuations = [f" {
|
| 136 |
|
| 137 |
return {
|
| 138 |
"context": context,
|
|
|
|
| 75 |
train_examples = []
|
| 76 |
for row in few_shot_ds:
|
| 77 |
ctx = row["ctx"]
|
| 78 |
+
endings = [
|
| 79 |
+
ending if ending.startswith(" ") else f" {ending}"
|
| 80 |
+
for ending in row["endings"]
|
| 81 |
+
]
|
| 82 |
gold = int(row["label"])
|
| 83 |
train_examples.append({
|
| 84 |
"context": ctx,
|
|
|
|
| 88 |
|
| 89 |
for row in ds:
|
| 90 |
ctx = row["ctx"]
|
| 91 |
+
endings = [
|
| 92 |
+
ending if ending.startswith(" ") else f" {ending}"
|
| 93 |
+
for ending in row["endings"]
|
| 94 |
+
]
|
| 95 |
gold = int(row["label"])
|
| 96 |
examples.append({
|
| 97 |
"context": ctx,
|
|
|
|
| 138 |
choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
|
| 139 |
context = f"Question: {question}\n{choice_str}\nAnswer:"
|
| 140 |
|
| 141 |
+
continuations = [f" {l}" for l in labels]
|
| 142 |
|
| 143 |
return {
|
| 144 |
"context": context,
|
benchmark/train_cortex.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Supervised Cortex adapter tuning.
|
| 4 |
+
|
| 5 |
+
This trains only Cortex module parameters against the same multiple-choice
|
| 6 |
+
log-likelihood objective used by the benchmark runner. It is intended as a
|
| 7 |
+
small, explicit tuning step before expecting Cortex to outperform the base
|
| 8 |
+
model.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import sys
|
| 15 |
+
import time
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
# Ensure parent directory is on path for imports
|
| 20 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 21 |
+
|
| 22 |
+
from benchmark.runner import BenchmarkRunner
|
| 23 |
+
from benchmark.tasks import TASK_REGISTRY
|
| 24 |
+
from benchmark.tuning import cortex_auxiliary_loss, multiple_choice_loss
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def load_examples(task_names, n_per_task, seed):
    """Load ``n_per_task`` examples for each named task.

    Args:
        task_names: Iterable of keys into ``TASK_REGISTRY``.
        n_per_task: Number of examples requested from each task.
        seed: Seed forwarded to each task's ``load_examples``.

    Returns:
        A flat list of ``(task_name, example)`` tuples, in task order.
    """
    collected = []
    for name in task_names:
        entry = TASK_REGISTRY[name]
        # Registry entries may be classes/factories or ready-made instances.
        task = entry if not callable(entry) else entry()
        loaded = task.load_examples(n=n_per_task, seed=seed)
        for item in loaded:
            collected.append((name, item))
        print(f"Loaded {len(loaded)} examples for {name}")
    return collected
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def main():
    """Command-line entry point: tune Cortex modules on multiple-choice data.

    Parses CLI arguments, builds a ``BenchmarkRunner``, injects the Cortex
    modules, and optimizes only the parameters returned by
    ``surgeon.get_trainable_parameters()`` with AdamW. Each step adds the
    multiple-choice loss to the auxiliary Cortex loss. The tuned weights are
    written to ``--output`` when training finishes.

    Raises:
        RuntimeError: If no training examples could be loaded, or if the
            surgeon exposes no trainable parameters.
    """
    parser = argparse.ArgumentParser(description="Train Cortex modules on benchmark-style MC data")
    parser.add_argument(
        "--model", type=str, default="HuggingFaceTB/SmolLM2-135M",
        help="HuggingFace model ID to tune",
    )
    parser.add_argument(
        "--tasks", nargs="+", default=["hellaswag", "piqa", "arc-easy", "winogrande"],
        help="Tasks to train on",
    )
    parser.add_argument(
        "--n-train", type=int, default=8,
        help="Examples per task for tuning",
    )
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--device", type=str, default="auto",
        help="Device: cuda, mps, cpu, or auto",
    )
    parser.add_argument(
        "--dtype", type=str, default="float32",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--init-cortex-weights", type=str, default=None,
        help="Optional Cortex weights to resume from",
    )
    parser.add_argument(
        "--output", type=str, default="cortex_tuned.pt",
        help="Path to save tuned Cortex weights",
    )
    parser.add_argument("--log-every", type=int, default=4)
    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    runner = BenchmarkRunner(
        model_name=args.model,
        device=args.device,
        dtype=args.dtype,
        cortex_weights=args.init_cortex_weights,
    )
    runner.inject_cortex()

    model = runner.model
    tokenizer = runner.tokenizer
    surgeon = runner._surgeon
    model.train()

    examples = load_examples(args.tasks, args.n_train, args.seed)
    if not examples:
        raise RuntimeError("No training examples loaded")

    trainable_params = list(surgeon.get_trainable_parameters())
    if not trainable_params:
        # Fail with a clear message instead of the optimizer's generic
        # empty-parameter-list error.
        raise RuntimeError("No trainable Cortex parameters found")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    print(f"Training on {len(examples)} examples for {args.epochs} epoch(s)")
    start = time.time()

    for epoch in range(args.epochs):
        # Re-shuffle with an epoch-dependent seed so runs stay reproducible.
        rng = random.Random(args.seed + epoch)
        rng.shuffle(examples)

        total_loss = 0.0
        correct = 0
        seen = 0
        skipped = 0

        for step, (task_name, example) in enumerate(examples, start=1):
            optimizer.zero_grad(set_to_none=True)

            loss, pred = multiple_choice_loss(model, tokenizer, example, runner.device)
            if loss is None:
                # Example could not be scored; count it and move on.
                skipped += 1
                continue

            aux_loss = cortex_auxiliary_loss(model)
            train_loss = loss + aux_loss
            train_loss.backward()

            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)

            optimizer.step()

            seen += 1
            total_loss += float(train_loss.detach().cpu())
            correct += int(pred == example["gold_idx"])

            # A non-positive --log-every disables periodic logging instead of
            # raising ZeroDivisionError; the final step is always reported.
            periodic = args.log_every > 0 and step % args.log_every == 0
            if periodic or step == len(examples):
                avg_loss = total_loss / max(seen, 1)
                acc = correct / max(seen, 1)
                print(
                    f"epoch={epoch + 1} step={step}/{len(examples)} "
                    f"task={task_name} loss={avg_loss:.4f} acc={acc:.3f}"
                )

        avg_loss = total_loss / max(seen, 1)
        acc = correct / max(seen, 1)
        print(
            f"Epoch {epoch + 1} done: loss={avg_loss:.4f} "
            f"acc={acc:.3f} skipped={skipped}"
        )

    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    surgeon.save_cortex_modules(args.output)
    elapsed = time.time() - start
    print(f"Saved Cortex weights to {args.output} [{elapsed:.1f}s]")
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main()
|
benchmark/tuning.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training utilities for supervised Cortex adapter tuning.
|
| 3 |
+
|
| 4 |
+
These helpers keep the base model frozen and optimize only the modules managed by
|
| 5 |
+
CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a
|
| 6 |
+
small tuning run optimizes the same multiple-choice objective being evaluated.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import Dict, List, Optional, Tuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
from benchmark.scoring import reset_cortex_state
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def continuation_log_likelihood(
    model,
    tokenizer,
    context: str,
    continuation: str,
    device: str,
) -> Optional[torch.Tensor]:
    """Mean log-probability of the continuation tokens given the context.

    Runs outside ``no_grad``, so the result can be back-propagated through.

    Args:
        model: Callable returning an object with ``.logits``; its ``config``
            may define ``max_position_embeddings`` (2048 is used otherwise).
        tokenizer: Object with ``encode(text, add_special_tokens=False)``.
        context: Prompt text that precedes the continuation.
        continuation: Text whose tokens are scored.
        device: Device on which the input tensor is created.

    Returns:
        A scalar tensor, or ``None`` when the context or the continuation
        contributes no tokens, or when truncation to the model's maximum
        length leaves no continuation tokens.
    """
    prefix_len = len(tokenizer.encode(context, add_special_tokens=False))
    token_ids = tokenizer.encode(context + continuation, add_special_tokens=False)
    n_cont = len(token_ids) - prefix_len
    if prefix_len <= 0 or n_cont <= 0:
        return None

    input_ids = torch.tensor([token_ids], device=device)
    limit = getattr(model.config, "max_position_embeddings", 2048)
    if input_ids.shape[1] > limit:
        # Keep only what fits; the continuation may be shortened or lost.
        input_ids = input_ids[:, :limit]
        n_cont = min(n_cont, limit - prefix_len)
        if n_cont <= 0:
            return None

    # Each scored sequence is independent, so stale module state is cleared.
    reset_cortex_state(model, batch_size=input_ids.shape[0])
    logits = model(input_ids).logits

    # Logits at position i predict token i + 1, hence the one-step offset.
    stop = prefix_len + n_cont
    predicted = F.log_softmax(logits[0, prefix_len - 1 : stop - 1, :], dim=-1)
    targets = input_ids[0, prefix_len:stop]
    picked = predicted.gather(1, targets.unsqueeze(1)).squeeze(1)
    return picked.mean()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def multiple_choice_loss(
    model,
    tokenizer,
    example: Dict,
    device: str,
) -> Tuple[Optional[torch.Tensor], Optional[int]]:
    """
    Cross-entropy over continuation log-likelihoods.

    Each entry of ``example["continuations"]`` is scored with
    ``continuation_log_likelihood`` against ``example["context"]``. The scores
    are treated as logits over the choices and compared with
    ``example["gold_idx"]``.

    Returns:
        (loss, prediction). If an example cannot be scored, both are None.
        That covers an empty list of continuations and any single
        continuation that cannot be scored.
    """
    continuations = example["continuations"]
    if len(continuations) == 0:
        # torch.stack rejects an empty list; report the example as unscorable
        # so the caller can skip it like any other unscorable example.
        return None, None

    scores: List[torch.Tensor] = []
    for continuation in continuations:
        score = continuation_log_likelihood(
            model, tokenizer, example["context"], continuation, device
        )
        if score is None:
            return None, None
        scores.append(score)

    # Shape [1, num_choices]: one "batch" row of per-choice scores.
    logits = torch.stack(scores).unsqueeze(0)
    gold = torch.tensor([example["gold_idx"]], device=device)
    loss = F.cross_entropy(logits, gold)
    pred = int(logits.argmax(dim=-1).item())
    return loss, pred
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def cortex_auxiliary_loss(model) -> torch.Tensor:
    """Sum the auxiliary losses reported by the attached Cortex modules.

    Only modules exposing a ``get_budget_loss`` attribute contribute. The
    result lives on the device of the model's first parameter.

    Returns:
        A scalar tensor. It is ``0.0`` when the model has no
        ``_cortex_surgeon`` attribute or no module reports a loss.
    """
    device = next(model.parameters()).device
    surgeon = getattr(model, "_cortex_surgeon", None)
    modules = surgeon.modules.values() if surgeon is not None else ()

    getters = (getattr(module, "get_budget_loss", None) for module in modules)
    collected = [getter().to(device) for getter in getters if getter is not None]

    if collected:
        return torch.stack(collected).sum()
    return torch.tensor(0.0, device=device)
|
cortex/adaptive_depth.py
CHANGED
|
@@ -77,6 +77,10 @@ class AdaptiveDepth(CortexModule):
|
|
| 77 |
|
| 78 |
# Initialize gate to be "open" (execute layer) by default
|
| 79 |
nn.init.constant_(self.gate_net[-1].bias, 2.0) # sigmoid(2) ≈ 0.88
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# Buffers for monitoring
|
| 82 |
self.register_buffer("_pre_layer_hidden", None, persistent=False)
|
|
@@ -91,6 +95,7 @@ class AdaptiveDepth(CortexModule):
|
|
| 91 |
self,
|
| 92 |
hidden_states: torch.Tensor,
|
| 93 |
layer_idx: int,
|
|
|
|
| 94 |
**kwargs
|
| 95 |
) -> torch.Tensor:
|
| 96 |
"""
|
|
@@ -103,22 +108,29 @@ class AdaptiveDepth(CortexModule):
|
|
| 103 |
"""
|
| 104 |
# Compute gate value per token
|
| 105 |
gate_logit = self.gate_net(hidden_states) / self.temperature # [B, T, 1]
|
| 106 |
-
|
| 107 |
|
| 108 |
# Straight-through estimator for hard gating
|
| 109 |
if self.gate_type == "straight_through" and self.training:
|
| 110 |
-
hard_gate = (
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
self._gate_values =
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
# Budget regularization loss
|
| 119 |
-
avg_gate =
|
| 120 |
budget_loss = self.budget_loss_weight * (avg_gate - self.target_budget).pow(2)
|
| 121 |
-
self._budget_loss = budget_loss
|
| 122 |
|
| 123 |
return gated_output
|
| 124 |
|
|
@@ -139,4 +151,4 @@ class AdaptiveDepth(CortexModule):
|
|
| 139 |
|
| 140 |
def extra_repr(self):
|
| 141 |
return (f"hidden_dim={self.hidden_dim}, target_budget={self.target_budget}, "
|
| 142 |
-
f"gate_type={self.gate_type}, {super().extra_repr()}")
|
|
|
|
| 77 |
|
| 78 |
# Initialize gate to be "open" (execute layer) by default
|
| 79 |
nn.init.constant_(self.gate_net[-1].bias, 2.0) # sigmoid(2) ≈ 0.88
|
| 80 |
+
|
| 81 |
+
# Blend from identity toward the learned depth gate during training.
|
| 82 |
+
# Initial value gives an exact no-op, preserving pretrained behavior.
|
| 83 |
+
self.gate_residual_scale = nn.Parameter(torch.tensor(0.0))
|
| 84 |
|
| 85 |
# Buffers for monitoring
|
| 86 |
self.register_buffer("_pre_layer_hidden", None, persistent=False)
|
|
|
|
| 95 |
self,
|
| 96 |
hidden_states: torch.Tensor,
|
| 97 |
layer_idx: int,
|
| 98 |
+
pre_layer_hidden: Optional[torch.Tensor] = None,
|
| 99 |
**kwargs
|
| 100 |
) -> torch.Tensor:
|
| 101 |
"""
|
|
|
|
| 108 |
"""
|
| 109 |
# Compute gate value per token
|
| 110 |
gate_logit = self.gate_net(hidden_states) / self.temperature # [B, T, 1]
|
| 111 |
+
learned_gate = torch.sigmoid(gate_logit)
|
| 112 |
|
| 113 |
# Straight-through estimator for hard gating
|
| 114 |
if self.gate_type == "straight_through" and self.training:
|
| 115 |
+
hard_gate = (learned_gate > 0.5).float()
|
| 116 |
+
learned_gate = hard_gate - learned_gate.detach() + learned_gate # STE
|
| 117 |
+
|
| 118 |
+
# At initialization, gate_residual_scale = 0 and effective_gate = 1, so
|
| 119 |
+
# the injected module preserves the original model exactly.
|
| 120 |
+
effective_gate = 1.0 + self.gate_residual_scale * (learned_gate - 1.0)
|
| 121 |
|
| 122 |
+
self._gate_values = effective_gate.detach()
|
| 123 |
|
| 124 |
+
if pre_layer_hidden is not None and pre_layer_hidden.shape == hidden_states.shape:
|
| 125 |
+
residual_update = hidden_states - pre_layer_hidden
|
| 126 |
+
gated_output = pre_layer_hidden + effective_gate * residual_update
|
| 127 |
+
else:
|
| 128 |
+
gated_output = hidden_states
|
| 129 |
|
| 130 |
# Budget regularization loss
|
| 131 |
+
avg_gate = learned_gate.mean()
|
| 132 |
budget_loss = self.budget_loss_weight * (avg_gate - self.target_budget).pow(2)
|
| 133 |
+
self._budget_loss = budget_loss
|
| 134 |
|
| 135 |
return gated_output
|
| 136 |
|
|
|
|
| 151 |
|
| 152 |
def extra_repr(self):
|
| 153 |
return (f"hidden_dim={self.hidden_dim}, target_budget={self.target_budget}, "
|
| 154 |
+
f"gate_type={self.gate_type}, {super().extra_repr()}")
|
cortex/backtrack_head.py
CHANGED
|
@@ -153,6 +153,11 @@ class BacktrackHead(CortexModule):
|
|
| 153 |
def get_confidence_history(self) -> torch.Tensor:
|
| 154 |
"""Return the confidence scores across all layers from the last forward pass."""
|
| 155 |
return self._confidence_history.clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
def was_triggered(self) -> bool:
|
| 158 |
"""Whether backtracking was triggered in the last forward pass."""
|
|
@@ -160,4 +165,4 @@ class BacktrackHead(CortexModule):
|
|
| 160 |
|
| 161 |
def extra_repr(self):
|
| 162 |
return (f"hidden_dim={self.hidden_dim}, drop_threshold={self.drop_threshold}, "
|
| 163 |
-
f"{super().extra_repr()}")
|
|
|
|
| 153 |
def get_confidence_history(self) -> torch.Tensor:
|
| 154 |
"""Return the confidence scores across all layers from the last forward pass."""
|
| 155 |
return self._confidence_history.clone()
|
| 156 |
+
|
| 157 |
+
def reset_state(self, batch_size: int = 1):
    """Wipe the recorded confidence history and clear the triggered flag.

    ``batch_size`` is accepted for interface compatibility with other
    modules; this implementation does not use it.
    """
    self._confidence_history.zero_()
    flag_device = self._last_triggered.device
    self._last_triggered = torch.tensor(False, device=flag_device)
|
| 161 |
|
| 162 |
def was_triggered(self) -> bool:
|
| 163 |
"""Whether backtracking was triggered in the last forward pass."""
|
|
|
|
| 165 |
|
| 166 |
def extra_repr(self):
|
| 167 |
return (f"hidden_dim={self.hidden_dim}, drop_threshold={self.drop_threshold}, "
|
| 168 |
+
f"{super().extra_repr()}")
|
cortex/core.py
CHANGED
|
@@ -85,6 +85,10 @@ class CortexModule(ABC, nn.Module):
|
|
| 85 |
for hook in self._hooks:
|
| 86 |
hook.remove()
|
| 87 |
self._hooks.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
def enable(self):
|
| 90 |
self._active = True
|
|
@@ -352,6 +356,7 @@ class CortexSurgeon:
|
|
| 352 |
logger.info(f"Injected '{name}' into layers {target_layer_idxs}")
|
| 353 |
|
| 354 |
self._operated = True
|
|
|
|
| 355 |
|
| 356 |
total_params = sum(p.numel() for p in self.model.parameters())
|
| 357 |
cortex_params = sum(
|
|
@@ -371,11 +376,12 @@ class CortexSurgeon:
|
|
| 371 |
def post_ffn_hook(mod, inp, output, _module=module, _layer_idx=layer_idx):
|
| 372 |
if not _module.is_active:
|
| 373 |
return output
|
|
|
|
| 374 |
if isinstance(output, tuple):
|
| 375 |
hidden = output[0]
|
| 376 |
-
hidden = _module(hidden, layer_idx=_layer_idx)
|
| 377 |
return (hidden,) + output[1:]
|
| 378 |
-
return _module(output, layer_idx=_layer_idx)
|
| 379 |
return layer.register_forward_hook(post_ffn_hook)
|
| 380 |
|
| 381 |
elif point == InjectionPoint.PRE_ATTENTION:
|
|
@@ -455,6 +461,8 @@ class CortexSurgeon:
|
|
| 455 |
|
| 456 |
self.modules.clear()
|
| 457 |
self._operated = False
|
|
|
|
|
|
|
| 458 |
logger.info("All Cortex modules removed, model restored")
|
| 459 |
|
| 460 |
def get_trainable_parameters(self):
|
|
@@ -490,4 +498,4 @@ class CortexSurgeon:
|
|
| 490 |
for name, module in self.modules.items():
|
| 491 |
if name in state:
|
| 492 |
module.load_state_dict(state[name]["state_dict"])
|
| 493 |
-
logger.info(f"Loaded weights for '{name}'")
|
|
|
|
| 85 |
for hook in self._hooks:
|
| 86 |
hook.remove()
|
| 87 |
self._hooks.clear()
|
| 88 |
+
|
| 89 |
+
def reset_state(self, batch_size: int = 1):
    """Hook for clearing per-example runtime state.

    The base implementation does nothing; modules that keep state between
    forward passes override it.
    """
    return None
|
| 92 |
|
| 93 |
def enable(self):
|
| 94 |
self._active = True
|
|
|
|
| 356 |
logger.info(f"Injected '{name}' into layers {target_layer_idxs}")
|
| 357 |
|
| 358 |
self._operated = True
|
| 359 |
+
setattr(self.model, "_cortex_surgeon", self)
|
| 360 |
|
| 361 |
total_params = sum(p.numel() for p in self.model.parameters())
|
| 362 |
cortex_params = sum(
|
|
|
|
| 376 |
def post_ffn_hook(mod, inp, output, _module=module, _layer_idx=layer_idx):
|
| 377 |
if not _module.is_active:
|
| 378 |
return output
|
| 379 |
+
pre_hidden = inp[0] if isinstance(inp, tuple) and len(inp) > 0 else None
|
| 380 |
if isinstance(output, tuple):
|
| 381 |
hidden = output[0]
|
| 382 |
+
hidden = _module(hidden, layer_idx=_layer_idx, pre_layer_hidden=pre_hidden)
|
| 383 |
return (hidden,) + output[1:]
|
| 384 |
+
return _module(output, layer_idx=_layer_idx, pre_layer_hidden=pre_hidden)
|
| 385 |
return layer.register_forward_hook(post_ffn_hook)
|
| 386 |
|
| 387 |
elif point == InjectionPoint.PRE_ATTENTION:
|
|
|
|
| 461 |
|
| 462 |
self.modules.clear()
|
| 463 |
self._operated = False
|
| 464 |
+
if getattr(self.model, "_cortex_surgeon", None) is self:
|
| 465 |
+
delattr(self.model, "_cortex_surgeon")
|
| 466 |
logger.info("All Cortex modules removed, model restored")
|
| 467 |
|
| 468 |
def get_trainable_parameters(self):
|
|
|
|
| 498 |
for name, module in self.modules.items():
|
| 499 |
if name in state:
|
| 500 |
module.load_state_dict(state[name]["state_dict"])
|
| 501 |
+
logger.info(f"Loaded weights for '{name}'")
|
cortex/hallucination_gate.py
CHANGED
|
@@ -102,6 +102,10 @@ class HallucinationGate(CortexModule):
|
|
| 102 |
# Learnable gate bias per dimension — allows the model to learn which
|
| 103 |
# dimensions are safe to suppress vs which should always pass through
|
| 104 |
self.dim_gate = nn.Parameter(torch.zeros(1, 1, hidden_dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Running confidence for monitoring/logging
|
| 107 |
self.register_buffer("_last_confidence", torch.tensor(0.5), persistent=False)
|
|
@@ -111,6 +115,7 @@ class HallucinationGate(CortexModule):
|
|
| 111 |
self,
|
| 112 |
hidden_states: torch.Tensor,
|
| 113 |
layer_idx: int,
|
|
|
|
| 114 |
**kwargs
|
| 115 |
) -> torch.Tensor:
|
| 116 |
"""
|
|
@@ -131,13 +136,21 @@ class HallucinationGate(CortexModule):
|
|
| 131 |
# Compute per-dimension gate
|
| 132 |
dim_bias = torch.sigmoid(self.dim_gate) # [1, 1, D], in (0,1)
|
| 133 |
|
| 134 |
-
# Effective gate: combines token-level confidence with dimension-level bias
|
|
|
|
|
|
|
| 135 |
# High dim_bias = dimension always passes through
|
| 136 |
# Low dim_bias = dimension is gated by confidence
|
| 137 |
-
|
|
|
|
| 138 |
|
| 139 |
-
# Apply gate
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
return gated_output
|
| 143 |
|
|
@@ -148,4 +161,4 @@ class HallucinationGate(CortexModule):
|
|
| 148 |
def extra_repr(self):
|
| 149 |
return (f"hidden_dim={self.hidden_dim}, "
|
| 150 |
f"suppression_strength={self.suppression_strength}, "
|
| 151 |
-
f"{super().extra_repr()}")
|
|
|
|
| 102 |
# Learnable gate bias per dimension — allows the model to learn which
|
| 103 |
# dimensions are safe to suppress vs which should always pass through
|
| 104 |
self.dim_gate = nn.Parameter(torch.zeros(1, 1, hidden_dim))
|
| 105 |
+
|
| 106 |
+
# Start as an exact no-op. Once this scalar moves away from zero during
|
| 107 |
+
# training, the confidence probe and dimension gate control suppression.
|
| 108 |
+
self.suppression_scale = nn.Parameter(torch.tensor(0.0))
|
| 109 |
|
| 110 |
# Running confidence for monitoring/logging
|
| 111 |
self.register_buffer("_last_confidence", torch.tensor(0.5), persistent=False)
|
|
|
|
| 115 |
self,
|
| 116 |
hidden_states: torch.Tensor,
|
| 117 |
layer_idx: int,
|
| 118 |
+
pre_layer_hidden: Optional[torch.Tensor] = None,
|
| 119 |
**kwargs
|
| 120 |
) -> torch.Tensor:
|
| 121 |
"""
|
|
|
|
| 136 |
# Compute per-dimension gate
|
| 137 |
dim_bias = torch.sigmoid(self.dim_gate) # [1, 1, D], in (0,1)
|
| 138 |
|
| 139 |
+
# Effective gate: combines token-level confidence with dimension-level bias.
|
| 140 |
+
# suppression_scale is zero at initialization, so this module is exactly
|
| 141 |
+
# transparent before Cortex-specific training.
|
| 142 |
# High dim_bias = dimension always passes through
|
| 143 |
# Low dim_bias = dimension is gated by confidence
|
| 144 |
+
effective_suppression = self.suppression_strength * self.suppression_scale
|
| 145 |
+
gate = 1.0 - effective_suppression * (1.0 - confidence) * (1.0 - dim_bias)
|
| 146 |
|
| 147 |
+
# Apply the gate to the layer update when the hook can provide the block
|
| 148 |
+
# input. Fall back to gating the stream itself for manually called modules.
|
| 149 |
+
if pre_layer_hidden is not None and pre_layer_hidden.shape == hidden_states.shape:
|
| 150 |
+
residual_update = hidden_states - pre_layer_hidden
|
| 151 |
+
gated_output = pre_layer_hidden + gate * residual_update
|
| 152 |
+
else:
|
| 153 |
+
gated_output = hidden_states * gate
|
| 154 |
|
| 155 |
return gated_output
|
| 156 |
|
|
|
|
| 161 |
def extra_repr(self):
|
| 162 |
return (f"hidden_dim={self.hidden_dim}, "
|
| 163 |
f"suppression_strength={self.suppression_strength}, "
|
| 164 |
+
f"{super().extra_repr()}")
|
cortex/memory_bank.py
CHANGED
|
@@ -96,6 +96,10 @@ class MemoryBank(CortexModule):
|
|
| 96 |
def reset_memory(self, batch_size: int = 1):
|
| 97 |
"""Reset memory to initial state."""
|
| 98 |
self._memory_state = self.memory_init.expand(batch_size, -1, -1).clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
def forward(
|
| 101 |
self,
|
|
@@ -167,4 +171,4 @@ class MemoryBank(CortexModule):
|
|
| 167 |
|
| 168 |
def extra_repr(self):
|
| 169 |
return (f"hidden_dim={self.hidden_dim}, num_slots={self.num_slots}, "
|
| 170 |
-
f"num_heads={self.num_heads}, {super().extra_repr()}")
|
|
|
|
| 96 |
def reset_memory(self, batch_size: int = 1):
|
| 97 |
"""Reset memory to initial state."""
|
| 98 |
self._memory_state = self.memory_init.expand(batch_size, -1, -1).clone()
|
| 99 |
+
|
| 100 |
+
def reset_state(self, batch_size: int = 1):
|
| 101 |
+
"""Reset memory between independent benchmark examples."""
|
| 102 |
+
self.reset_memory(batch_size=batch_size)
|
| 103 |
|
| 104 |
def forward(
|
| 105 |
self,
|
|
|
|
| 171 |
|
| 172 |
def extra_repr(self):
|
| 173 |
return (f"hidden_dim={self.hidden_dim}, num_slots={self.num_slots}, "
|
| 174 |
+
f"num_heads={self.num_heads}, {super().extra_repr()}")
|
test_cortex.py
CHANGED
|
@@ -4,7 +4,7 @@ Verify that:
|
|
| 4 |
1. All modules inject without errors
|
| 5 |
2. Forward pass works
|
| 6 |
3. Gradients flow only through Cortex parameters
|
| 7 |
-
4.
|
| 8 |
5. Each module's specific functionality works
|
| 9 |
|
| 10 |
Usage:
|
|
@@ -114,4 +114,4 @@ def main():
|
|
| 114 |
print(f"{'='*60}")
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
-
main()
|
|
|
|
| 4 |
1. All modules inject without errors
|
| 5 |
2. Forward pass works
|
| 6 |
3. Gradients flow only through Cortex parameters
|
| 7 |
+
4. Freshly injected modules preserve base outputs
|
| 8 |
5. Each module's specific functionality works
|
| 9 |
|
| 10 |
Usage:
|
|
|
|
| 114 |
print(f"{'='*60}")
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
+
main()
|