theapemachine committed on
Commit 0de2901 · 1 Parent(s): 0ac64e3

Enhance benchmark and Cortex modules with new training utilities and improved state management. Update README with example output for Llama-3.2-1B and add training CLI for Cortex module tuning. Refactor scoring functions to reset Cortex state between examples and ensure consistent output. Modify task handling to ensure proper formatting of input data.

README.md CHANGED
@@ -181,28 +181,30 @@ python -m benchmark.run_benchmark --n 10 --model meta-llama/Llama-3.2-1B --tasks
181
  - **Multiple-choice tasks:** Log-likelihood scoring — computes average log-probability the model assigns to each continuation, picks the highest. This is the standard approach used by lm-evaluation-harness and Open LLM Leaderboard.
182
  - **Generation tasks:** Greedy decode + substring match against expected answer.
183
 
184
- ### Example Output (SmolLM2-135M, n=20)
185
 
186
  ```
187
  ======================================================================
188
- BENCHMARK SUMMARY: HuggingFaceTB/SmolLM2-135M
189
- n=20 per task, device=cuda
190
  ======================================================================
191
 
192
  Task Base Cortex Delta
193
  --------------------------------------------------
194
- hellaswag 0.3500 0.5000 +0.1500 ↑
195
- piqa 0.5000 0.5000 +0.0000
196
- arc-easy 0.2500 0.4500 +0.2000 ↑
197
- winogrande 0.6500 0.6500 +0.0000
198
- passkey 1.0000 0.8889 -0.1111 ↓
199
- multi_hop 0.6250 0.2500 -0.3750 ↓
200
-
201
- Cortex overhead: 4,296,134 params (3.19%)
 
 
202
  ======================================================================
203
  ```
204
 
205
- > **Note:** Cortex modules are untrained at injection (zero-initialized gates). The slight degradation on generation tasks (passkey, multi-hop) is expected these require module training to improve. Standard log-likelihood tasks remain stable because zero-init gates are nearly transparent.
206
 
207
  ### Programmatic Usage
208
 
@@ -328,6 +330,26 @@ surgeon.modules["steering"].set_direction("truthfulness", direction, alpha=10.0)
328
 
329
  ## Training
330
 
331
  ```python
332
  import torch.optim as optim
333
 
 
181
  - **Multiple-choice tasks:** Log-likelihood scoring — computes average log-probability the model assigns to each continuation, picks the highest. This is the standard approach used by lm-evaluation-harness and Open LLM Leaderboard.
182
  - **Generation tasks:** Greedy decode + substring match against expected answer.
183
 
184
+ ### Example Output (Llama-3.2-1B, n=10)
185
 
186
  ```
187
  ======================================================================
188
+ BENCHMARK SUMMARY: meta-llama/Llama-3.2-1B
189
+ n=10 per task, device=mps
190
  ======================================================================
191
 
192
  Task Base Cortex Delta
193
  --------------------------------------------------
194
+ hellaswag 0.6000 0.6000 +0.0000
195
+ piqa 0.2000 0.2000 +0.0000
196
+ arc-easy 0.4000 0.4000 +0.0000
197
+ arc-challenge 0.5000 0.5000 +0.0000
198
+ winogrande 0.6000 0.6000 +0.0000
199
+ mmlu 0.4000 0.4000 +0.0000
200
+ passkey 1.0000 1.0000 +0.0000
201
+ multi_hop 1.0000 1.0000 +0.0000
202
+
203
+ Cortex overhead: 53,708,968 params (4.35%)
204
  ======================================================================
205
  ```
206
 
207
+ > **Note:** Cortex modules are untrained at injection and initialize as exact no-ops for model behavior. Freshly injected modules should match the base model; positive deltas require Cortex-specific training or calibrated steering directions.
208
 
209
  ### Programmatic Usage
210
 
 
330
 
331
  ## Training
332
 
333
+ For benchmark-style supervised tuning, use the training CLI. It freezes the base
334
+ model, injects Cortex modules, optimizes only Cortex parameters, and saves the
335
+ adapter weights:
336
+
337
+ ```bash
338
+ python -m benchmark.train_cortex \
339
+ --model meta-llama/Llama-3.2-1B \
340
+ --tasks hellaswag piqa arc-easy winogrande \
341
+ --n-train 32 \
342
+ --epochs 1 \
343
+ --output cortex_tuned.pt
344
+
345
+ python -m benchmark.run_benchmark \
346
+ --model meta-llama/Llama-3.2-1B \
347
+ --cortex-weights cortex_tuned.pt \
348
+ --n 50
349
+ ```
350
+
351
+ For custom training loops:
352
+
353
  ```python
354
  import torch.optim as optim
355
 
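The two scoring modes described in the README hunk can be illustrated without a model. This is a minimal sketch, not the benchmark's implementation: the per-token log-probabilities are made-up numbers, and the helper names `pick_by_avg_logprob` and `substring_match` are hypothetical. In the real runner these values come from a forward pass.

```python
from typing import List, Sequence


def pick_by_avg_logprob(token_logprobs: Sequence[Sequence[float]]) -> int:
    """Return the index of the continuation with the highest mean log-probability."""
    averages = [sum(lp) / len(lp) for lp in token_logprobs]
    return max(range(len(averages)), key=averages.__getitem__)


def substring_match(generated: str, expected: str) -> bool:
    """Generation-task check: the expected answer appears somewhere in the output."""
    return expected in generated


# Illustrative (made-up) per-token log-probabilities for three continuations.
candidates: List[List[float]] = [
    [-2.0, -1.0],        # mean -1.5
    [-0.5, -0.7, -0.6],  # mean -0.6
    [-0.1, -3.0],        # mean -1.55
]
best = pick_by_avg_logprob(candidates)
```

Averaging over tokens, rather than summing, keeps longer continuations from being penalized purely for their length.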
benchmark/run_benchmark.py CHANGED
@@ -68,6 +68,10 @@ def main():
68
  "--output", type=str, default=None,
69
  help="Path to save JSON results",
70
  )
 
 
 
 
71
 
72
  args = parser.parse_args()
73
 
@@ -77,6 +81,7 @@ def main():
77
  model_name=args.model,
78
  device=args.device,
79
  dtype=args.dtype,
 
80
  )
81
 
82
  n = args.n if args.n > 0 else None
 
68
  "--output", type=str, default=None,
69
  help="Path to save JSON results",
70
  )
71
+ parser.add_argument(
72
+ "--cortex-weights", type=str, default=None,
73
+ help="Optional Cortex weights file to load before the Cortex phase",
74
+ )
75
 
76
  args = parser.parse_args()
77
 
 
81
  model_name=args.model,
82
  device=args.device,
83
  dtype=args.dtype,
84
+ cortex_weights=args.cortex_weights,
85
  )
86
 
87
  n = args.n if args.n > 0 else None
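As a quick check of how the new flag reaches `BenchmarkRunner`, the sketch below isolates just the `--cortex-weights` argument. It assumes nothing beyond the standard `argparse` behavior of turning dashes in an option name into underscores on the parsed namespace; the file name is only an example.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cortex-weights", type=str, default=None,
    help="Optional Cortex weights file to load before the Cortex phase",
)

# argparse stores "--cortex-weights" under the attribute "cortex_weights".
with_flag = parser.parse_args(["--cortex-weights", "cortex_tuned.pt"])
without_flag = parser.parse_args([])
```

When the flag is omitted the value stays `None`, which is why the runner can guard the load with a plain `if self.cortex_weights:`.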
benchmark/runner.py CHANGED
@@ -40,8 +40,10 @@ class BenchmarkRunner:
40
  model_name: str = "HuggingFaceTB/SmolLM2-135M",
41
  device: str = "auto",
42
  dtype: str = "float32",
 
43
  ):
44
  self.model_name = model_name
 
45
 
46
  if device == "auto":
47
  self.device = resolve_torch_device("auto")
@@ -167,6 +169,10 @@ class BenchmarkRunner:
167
  ))
168
 
169
  surgeon.operate(freeze_base=True)
 
 
 
 
170
 
171
  report = surgeon.get_parameter_report()
172
  total_cortex = sum(info["trainable"] for info in report.values())
 
40
  model_name: str = "HuggingFaceTB/SmolLM2-135M",
41
  device: str = "auto",
42
  dtype: str = "float32",
43
+ cortex_weights: Optional[str] = None,
44
  ):
45
  self.model_name = model_name
46
+ self.cortex_weights = cortex_weights
47
 
48
  if device == "auto":
49
  self.device = resolve_torch_device("auto")
 
169
  ))
170
 
171
  surgeon.operate(freeze_base=True)
172
+
173
+ if self.cortex_weights:
174
+ surgeon.load_cortex_modules(self.cortex_weights)
175
+ print(f" Loaded Cortex weights: {self.cortex_weights}")
176
 
177
  report = surgeon.get_parameter_report()
178
  total_cortex = sum(info["trainable"] for info in report.values())
benchmark/scoring.py CHANGED
@@ -17,6 +17,15 @@ import re
17
  from cortex.torch_device import resolve_torch_device
18
 
19
 
 
 
 
 
 
 
 
 
 
20
  @torch.no_grad()
21
  def log_likelihood_score(
22
  model,
@@ -63,6 +72,7 @@ def log_likelihood_score(
63
 
64
  # Forward pass
65
  input_ids = torch.tensor([full_ids], device=device)
 
66
 
67
  # Truncate if too long for model
68
  max_len = getattr(model.config, "max_position_embeddings", 2048)
@@ -121,6 +131,7 @@ def generate_and_check(
121
  if device is None:
122
  device = resolve_torch_device("auto")
123
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
 
124
 
125
  # Pad token
126
  pad_token_id = tokenizer.pad_token_id
 
17
  from cortex.torch_device import resolve_torch_device
18
 
19
 
20
+ def reset_cortex_state(model, batch_size: int = 1):
21
+ """Reset runtime state for injected Cortex modules between independent examples."""
22
+ surgeon = getattr(model, "_cortex_surgeon", None)
23
+ if surgeon is None:
24
+ return
25
+ for module in surgeon.modules.values():
26
+ module.reset_state(batch_size=batch_size)
27
+
28
+
29
  @torch.no_grad()
30
  def log_likelihood_score(
31
  model,
 
72
 
73
  # Forward pass
74
  input_ids = torch.tensor([full_ids], device=device)
75
+ reset_cortex_state(model, batch_size=input_ids.shape[0])
76
 
77
  # Truncate if too long for model
78
  max_len = getattr(model.config, "max_position_embeddings", 2048)
 
131
  if device is None:
132
  device = resolve_torch_device("auto")
133
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
134
+ reset_cortex_state(model, batch_size=inputs["input_ids"].shape[0])
135
 
136
  # Pad token
137
  pad_token_id = tokenizer.pad_token_id
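The reset helper added in this file can be exercised with stand-in objects. The sketch below copies the `reset_cortex_state` logic from the hunk; `CountingModule`, `FakeSurgeon`, and `FakeModel` are hypothetical stubs invented for illustration and are not part of the Cortex code.

```python
class CountingModule:
    """Hypothetical stand-in for a stateful Cortex module."""

    def __init__(self):
        self.state = []
        self.resets = 0

    def reset_state(self, batch_size: int = 1):
        self.state = [0.0] * batch_size
        self.resets += 1


class FakeSurgeon:
    """Hypothetical stand-in exposing only the `modules` mapping."""

    def __init__(self, modules):
        self.modules = modules


class FakeModel:
    """Hypothetical stand-in for a model object."""


def reset_cortex_state(model, batch_size: int = 1):
    # Same logic as the helper in the diff: look up the surgeon, reset each module.
    surgeon = getattr(model, "_cortex_surgeon", None)
    if surgeon is None:
        return
    for module in surgeon.modules.values():
        module.reset_state(batch_size=batch_size)


plain = FakeModel()
reset_cortex_state(plain)  # no surgeon attached: the call does nothing

injected = FakeModel()
memory = CountingModule()
memory.state = [1.0, 2.0]  # leftover state from a previous example
injected._cortex_surgeon = FakeSurgeon({"memory": memory})
reset_cortex_state(injected, batch_size=1)
```

Because the lookup uses `getattr` with a default, the same scoring functions work unchanged during the base-model phase, when no surgeon has been attached.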
benchmark/tasks.py CHANGED
@@ -75,7 +75,10 @@ class HellaSwag(BenchmarkTask):
75
  train_examples = []
76
  for row in few_shot_ds:
77
  ctx = row["ctx"]
78
- endings = row["endings"]
 
 
 
79
  gold = int(row["label"])
80
  train_examples.append({
81
  "context": ctx,
@@ -85,7 +88,10 @@ class HellaSwag(BenchmarkTask):
85
 
86
  for row in ds:
87
  ctx = row["ctx"]
88
- endings = row["endings"]
 
 
 
89
  gold = int(row["label"])
90
  examples.append({
91
  "context": ctx,
@@ -132,7 +138,7 @@ class ARC(BenchmarkTask):
132
  choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
133
  context = f"Question: {question}\n{choice_str}\nAnswer:"
134
 
135
- continuations = [f" {t}" for t in texts]
136
 
137
  return {
138
  "context": context,
 
75
  train_examples = []
76
  for row in few_shot_ds:
77
  ctx = row["ctx"]
78
+ endings = [
79
+ ending if ending.startswith(" ") else f" {ending}"
80
+ for ending in row["endings"]
81
+ ]
82
  gold = int(row["label"])
83
  train_examples.append({
84
  "context": ctx,
 
88
 
89
  for row in ds:
90
  ctx = row["ctx"]
91
+ endings = [
92
+ ending if ending.startswith(" ") else f" {ending}"
93
+ for ending in row["endings"]
94
+ ]
95
  gold = int(row["label"])
96
  examples.append({
97
  "context": ctx,
 
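The formatting changes in this file (HellaSwag endings in the hunks above, ARC continuations in the final hunk) can be tried in isolation. This is a sketch under stated assumptions: the question, choices, and endings are made-up examples, and `with_leading_space` is a hypothetical helper name wrapping the inline expression from the diff.

```python
def with_leading_space(ending: str) -> str:
    """Ensure a continuation starts with a space so it joins the context cleanly."""
    return ending if ending.startswith(" ") else f" {ending}"


# HellaSwag-style endings (made-up text).
endings = ["walks away.", " sits down."]
normalized = [with_leading_space(e) for e in endings]

# ARC-style prompt (made-up question and choices).
labels = ["A", "B", "C", "D"]
texts = ["iron", "wood", "glass", "water"]
question = "Which of these is a metal?"
choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
context = f"Question: {question}\n{choice_str}\nAnswer:"

# After this commit the scored continuations are the labels, not the answer texts.
continuations = [f" {l}" for l in labels]
```

Since the ARC prompt already lists every choice, scoring the short label after `Answer:` compares continuations of equal length; whether that improves accuracy for a given model is not something this sketch demonstrates.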
138
  choice_str = " ".join(f"{l}) {t}" for l, t in zip(labels, texts))
139
  context = f"Question: {question}\n{choice_str}\nAnswer:"
140
 
141
+ continuations = [f" {l}" for l in labels]
142
 
143
  return {
144
  "context": context,
benchmark/train_cortex.py ADDED
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+"""
+Supervised Cortex adapter tuning.
+
+This trains only Cortex module parameters against the same multiple-choice
+log-likelihood objective used by the benchmark runner. It is intended as a
+small, explicit tuning step before expecting Cortex to outperform the base
+model.
+"""
+
+import argparse
+import os
+import random
+import sys
+import time
+
+import torch
+
+# Ensure parent directory is on path for imports
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from benchmark.runner import BenchmarkRunner
+from benchmark.tasks import TASK_REGISTRY
+from benchmark.tuning import cortex_auxiliary_loss, multiple_choice_loss
+
+
+def load_examples(task_names, n_per_task, seed):
+    examples = []
+    for task_name in task_names:
+        task_cls = TASK_REGISTRY[task_name]
+        task = task_cls() if callable(task_cls) else task_cls
+        task_examples = task.load_examples(n=n_per_task, seed=seed)
+        examples.extend((task_name, ex) for ex in task_examples)
+        print(f"Loaded {len(task_examples)} examples for {task_name}")
+    return examples
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train Cortex modules on benchmark-style MC data")
+    parser.add_argument(
+        "--model", type=str, default="HuggingFaceTB/SmolLM2-135M",
+        help="HuggingFace model ID to tune",
+    )
+    parser.add_argument(
+        "--tasks", nargs="+", default=["hellaswag", "piqa", "arc-easy", "winogrande"],
+        help="Tasks to train on",
+    )
+    parser.add_argument(
+        "--n-train", type=int, default=8,
+        help="Examples per task for tuning",
+    )
+    parser.add_argument("--epochs", type=int, default=1)
+    parser.add_argument("--lr", type=float, default=1e-4)
+    parser.add_argument("--weight-decay", type=float, default=0.01)
+    parser.add_argument("--max-grad-norm", type=float, default=1.0)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument(
+        "--device", type=str, default="auto",
+        help="Device: cuda, mps, cpu, or auto",
+    )
+    parser.add_argument(
+        "--dtype", type=str, default="float32",
+        choices=["float32", "float16", "bfloat16"],
+    )
+    parser.add_argument(
+        "--init-cortex-weights", type=str, default=None,
+        help="Optional Cortex weights to resume from",
+    )
+    parser.add_argument(
+        "--output", type=str, default="cortex_tuned.pt",
+        help="Path to save tuned Cortex weights",
+    )
+    parser.add_argument("--log-every", type=int, default=4)
+    args = parser.parse_args()
+
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    runner = BenchmarkRunner(
+        model_name=args.model,
+        device=args.device,
+        dtype=args.dtype,
+        cortex_weights=args.init_cortex_weights,
+    )
+    runner.inject_cortex()
+
+    model = runner.model
+    tokenizer = runner.tokenizer
+    surgeon = runner._surgeon
+    model.train()
+
+    examples = load_examples(args.tasks, args.n_train, args.seed)
+    if not examples:
+        raise RuntimeError("No training examples loaded")
+
+    trainable_params = list(surgeon.get_trainable_parameters())
+    optimizer = torch.optim.AdamW(
+        trainable_params,
+        lr=args.lr,
+        weight_decay=args.weight_decay,
+    )
+
+    print(f"Training on {len(examples)} examples for {args.epochs} epoch(s)")
+    start = time.time()
+
+    for epoch in range(args.epochs):
+        rng = random.Random(args.seed + epoch)
+        rng.shuffle(examples)
+
+        total_loss = 0.0
+        correct = 0
+        seen = 0
+        skipped = 0
+
+        for step, (task_name, example) in enumerate(examples, start=1):
+            optimizer.zero_grad(set_to_none=True)
+
+            loss, pred = multiple_choice_loss(model, tokenizer, example, runner.device)
+            if loss is None:
+                skipped += 1
+                continue
+
+            aux_loss = cortex_auxiliary_loss(model)
+            train_loss = loss + aux_loss
+            train_loss.backward()
+
+            if args.max_grad_norm > 0:
+                torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
+
+            optimizer.step()
+
+            seen += 1
+            total_loss += float(train_loss.detach().cpu())
+            correct += int(pred == example["gold_idx"])
+
+            if step % args.log_every == 0 or step == len(examples):
+                avg_loss = total_loss / max(seen, 1)
+                acc = correct / max(seen, 1)
+                print(
+                    f"epoch={epoch + 1} step={step}/{len(examples)} "
+                    f"task={task_name} loss={avg_loss:.4f} acc={acc:.3f}"
+                )
+
+        avg_loss = total_loss / max(seen, 1)
+        acc = correct / max(seen, 1)
+        print(
+            f"Epoch {epoch + 1} done: loss={avg_loss:.4f} "
+            f"acc={acc:.3f} skipped={skipped}"
+        )
+
+    output_dir = os.path.dirname(args.output)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    surgeon.save_cortex_modules(args.output)
+    elapsed = time.time() - start
+    print(f"Saved Cortex weights to {args.output} [{elapsed:.1f}s]")
+
+
+if __name__ == "__main__":
+    main()
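Two small bookkeeping patterns in the training script are worth seeing on their own: the per-epoch shuffle seeded with `seed + epoch`, and the running averages guarded by `max(seen, 1)`. The sketch below is illustrative only. `epoch_order` and `running_stats` are hypothetical helper names, and unlike the script (which shuffles the same list in place each epoch) `epoch_order` shuffles a copy so each call is independent.

```python
import random
from typing import List, Sequence, Tuple


def epoch_order(items: Sequence, seed: int, epoch: int) -> List:
    """Return a reproducible shuffled copy of `items` for the given epoch."""
    order = list(items)
    random.Random(seed + epoch).shuffle(order)
    return order


def running_stats(total_loss: float, correct: int, seen: int) -> Tuple[float, float]:
    """Average loss and accuracy; max(seen, 1) avoids division by zero
    when every example so far was skipped."""
    denom = max(seen, 1)
    return total_loss / denom, correct / denom


first = epoch_order(range(10), seed=42, epoch=0)
again = epoch_order(range(10), seed=42, epoch=0)
empty_stats = running_stats(0.0, 0, 0)
some_stats = running_stats(3.0, 1, 2)
```

Using a dedicated `random.Random` instance keeps the example order independent of any other code that draws from the global random state.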
benchmark/tuning.py ADDED
@@ -0,0 +1,97 @@
+"""
+Training utilities for supervised Cortex adapter tuning.
+
+These helpers keep the base model frozen and optimize only the modules managed by
+CortexSurgeon. They intentionally mirror benchmark log-likelihood scoring so a
+small tuning run optimizes the same multiple-choice objective being evaluated.
+"""
+
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from benchmark.scoring import reset_cortex_state
+
+
+def continuation_log_likelihood(
+    model,
+    tokenizer,
+    context: str,
+    continuation: str,
+    device: str,
+) -> Optional[torch.Tensor]:
+    """Differentiable average continuation log-likelihood."""
+    ctx_ids = tokenizer.encode(context, add_special_tokens=False)
+    full_ids = tokenizer.encode(context + continuation, add_special_tokens=False)
+
+    cont_start = len(ctx_ids)
+    cont_length = len(full_ids) - cont_start
+    if cont_start <= 0 or cont_length <= 0:
+        return None
+
+    input_ids = torch.tensor([full_ids], device=device)
+    max_len = getattr(model.config, "max_position_embeddings", 2048)
+    if input_ids.shape[1] > max_len:
+        input_ids = input_ids[:, :max_len]
+        cont_length = min(cont_length, max_len - cont_start)
+        if cont_length <= 0:
+            return None
+
+    reset_cortex_state(model, batch_size=input_ids.shape[0])
+    outputs = model(input_ids)
+    logits = outputs.logits
+
+    shift_logits = logits[0, cont_start - 1 : cont_start + cont_length - 1, :]
+    shift_labels = input_ids[0, cont_start : cont_start + cont_length]
+    log_probs = F.log_softmax(shift_logits, dim=-1)
+    token_log_probs = log_probs.gather(1, shift_labels.unsqueeze(1)).squeeze(1)
+    return token_log_probs.mean()
+
+
+def multiple_choice_loss(
+    model,
+    tokenizer,
+    example: Dict,
+    device: str,
+) -> Tuple[Optional[torch.Tensor], Optional[int]]:
+    """
+    Cross-entropy over continuation log-likelihoods.
+
+    Returns:
+        (loss, prediction). If an example cannot be scored, both are None.
+    """
+    scores: List[torch.Tensor] = []
+    for continuation in example["continuations"]:
+        score = continuation_log_likelihood(
+            model, tokenizer, example["context"], continuation, device
+        )
+        if score is None:
+            return None, None
+        scores.append(score)
+
+    logits = torch.stack(scores).unsqueeze(0)
+    gold = torch.tensor([example["gold_idx"]], device=device)
+    loss = F.cross_entropy(logits, gold)
+    pred = int(logits.argmax(dim=-1).item())
+    return loss, pred
+
+
+def cortex_auxiliary_loss(model) -> torch.Tensor:
+    """Collect differentiable auxiliary losses exposed by Cortex modules."""
+    device = next(model.parameters()).device
+    surgeon = getattr(model, "_cortex_surgeon", None)
+    if surgeon is None:
+        return torch.tensor(0.0, device=device)
+
+    losses = []
+    for module in surgeon.modules.values():
+        get_budget_loss = getattr(module, "get_budget_loss", None)
+        if get_budget_loss is not None:
+            losses.append(get_budget_loss())
+
+    if not losses:
+        return torch.tensor(0.0, device=device)
+    return torch.stack([loss.to(device) for loss in losses]).sum()
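The objective in `multiple_choice_loss` treats the per-continuation average log-likelihoods as logits and applies cross-entropy against the gold index. A dependency-free sketch of that arithmetic is shown below, with made-up scores; `multiple_choice_cross_entropy` is a hypothetical name, and the real code uses `torch.nn.functional.cross_entropy` so that gradients can flow.

```python
import math
from typing import List, Tuple


def multiple_choice_cross_entropy(scores: List[float], gold_idx: int) -> Tuple[float, int]:
    """Return (loss, prediction) where loss = -log softmax(scores)[gold_idx]."""
    m = max(scores)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    loss = log_z - scores[gold_idx]
    pred = max(range(len(scores)), key=scores.__getitem__)
    return loss, pred


# Made-up average log-likelihoods for three continuations; index 1 is gold.
loss, pred = multiple_choice_cross_entropy([-1.5, -0.6, -1.55], gold_idx=1)

# With equal scores the model is indifferent, so the loss equals log(num_choices).
uniform_loss, _ = multiple_choice_cross_entropy([0.0, 0.0, 0.0, 0.0], gold_idx=2)
```

Minimizing this loss pushes the gold continuation's score above the others, which is the same comparison the benchmark makes when it picks the highest-scoring continuation.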
cortex/adaptive_depth.py CHANGED
@@ -77,6 +77,10 @@ class AdaptiveDepth(CortexModule):
77
 
78
  # Initialize gate to be "open" (execute layer) by default
79
  nn.init.constant_(self.gate_net[-1].bias, 2.0) # sigmoid(2) ≈ 0.88
 
 
 
 
80
 
81
  # Buffers for monitoring
82
  self.register_buffer("_pre_layer_hidden", None, persistent=False)
@@ -91,6 +95,7 @@ class AdaptiveDepth(CortexModule):
91
  self,
92
  hidden_states: torch.Tensor,
93
  layer_idx: int,
 
94
  **kwargs
95
  ) -> torch.Tensor:
96
  """
@@ -103,22 +108,29 @@ class AdaptiveDepth(CortexModule):
103
  """
104
  # Compute gate value per token
105
  gate_logit = self.gate_net(hidden_states) / self.temperature # [B, T, 1]
106
- gate = torch.sigmoid(gate_logit)
107
 
108
  # Straight-through estimator for hard gating
109
  if self.gate_type == "straight_through" and self.training:
110
- hard_gate = (gate > 0.5).float()
111
- gate = hard_gate - gate.detach() + gate # STE
 
 
 
 
112
 
113
- self._gate_values = gate.detach()
114
 
115
- # Gate the output: scale by gate, preserve gradients
116
- gated_output = gate * hidden_states + (1 - gate) * hidden_states.detach()
 
 
 
117
 
118
  # Budget regularization loss
119
- avg_gate = gate.mean()
120
  budget_loss = self.budget_loss_weight * (avg_gate - self.target_budget).pow(2)
121
- self._budget_loss = budget_loss.detach()
122
 
123
  return gated_output
124
 
@@ -139,4 +151,4 @@ class AdaptiveDepth(CortexModule):
139
 
140
  def extra_repr(self):
141
  return (f"hidden_dim={self.hidden_dim}, target_budget={self.target_budget}, "
142
- f"gate_type={self.gate_type}, {super().extra_repr()}")
 
77
 
78
  # Initialize gate to be "open" (execute layer) by default
79
  nn.init.constant_(self.gate_net[-1].bias, 2.0) # sigmoid(2) ≈ 0.88
80
+
81
+ # Blend from identity toward the learned depth gate during training.
82
+ # Initial value gives an exact no-op, preserving pretrained behavior.
83
+ self.gate_residual_scale = nn.Parameter(torch.tensor(0.0))
84
 
85
  # Buffers for monitoring
86
  self.register_buffer("_pre_layer_hidden", None, persistent=False)
 
95
  self,
96
  hidden_states: torch.Tensor,
97
  layer_idx: int,
98
+ pre_layer_hidden: Optional[torch.Tensor] = None,
99
  **kwargs
100
  ) -> torch.Tensor:
101
  """
 
108
  """
109
  # Compute gate value per token
110
  gate_logit = self.gate_net(hidden_states) / self.temperature # [B, T, 1]
111
+ learned_gate = torch.sigmoid(gate_logit)
112
 
113
  # Straight-through estimator for hard gating
114
  if self.gate_type == "straight_through" and self.training:
115
+ hard_gate = (learned_gate > 0.5).float()
116
+ learned_gate = hard_gate - learned_gate.detach() + learned_gate # STE
117
+
118
+ # At initialization, gate_residual_scale = 0 and effective_gate = 1, so
119
+ # the injected module preserves the original model exactly.
120
+ effective_gate = 1.0 + self.gate_residual_scale * (learned_gate - 1.0)
121
 
122
+ self._gate_values = effective_gate.detach()
123
 
124
+ if pre_layer_hidden is not None and pre_layer_hidden.shape == hidden_states.shape:
125
+ residual_update = hidden_states - pre_layer_hidden
126
+ gated_output = pre_layer_hidden + effective_gate * residual_update
127
+ else:
128
+ gated_output = hidden_states
129
 
130
  # Budget regularization loss
131
+ avg_gate = learned_gate.mean()
132
  budget_loss = self.budget_loss_weight * (avg_gate - self.target_budget).pow(2)
133
+ self._budget_loss = budget_loss
134
 
135
  return gated_output
136
 
 
151
 
152
  def extra_repr(self):
153
  return (f"hidden_dim={self.hidden_dim}, target_budget={self.target_budget}, "
154
+ f"gate_type={self.gate_type}, {super().extra_repr()}")
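The new gating arithmetic in `AdaptiveDepth` can be checked with plain floats. This is a sketch of the formulas in the hunk, not the module itself: the helper names are hypothetical, the hidden values are made up (and chosen to be exactly representable in binary), and the budget numbers are arbitrary.

```python
import math
from typing import List


def effective_gate(learned_gate: float, residual_scale: float) -> float:
    """Blend from identity (scale 0) toward the learned gate (scale 1)."""
    return 1.0 + residual_scale * (learned_gate - 1.0)


def gate_layer_update(pre: List[float], post: List[float], gate: float) -> List[float]:
    """Scale only the layer's residual update: pre + gate * (post - pre)."""
    return [p + gate * (h - p) for p, h in zip(pre, post)]


def budget_loss(avg_gate: float, target_budget: float, weight: float) -> float:
    """Quadratic penalty for drifting away from the target compute budget."""
    return weight * (avg_gate - target_budget) ** 2


learned = 1.0 / (1.0 + math.exp(-2.0))  # sigmoid(2), about 0.88

pre = [0.5, -1.0]    # block input (made up)
post = [1.5, 0.25]   # block output (made up)

# residual_scale = 0 at initialization, so the gate is 1 and the output is unchanged.
at_init = gate_layer_update(pre, post, effective_gate(learned, residual_scale=0.0))
# residual_scale = 1 recovers the learned gate.
fully_learned = effective_gate(learned, residual_scale=1.0)
# A fully closed gate skips the layer and returns the block input.
closed = gate_layer_update(pre, post, 0.0)
```

One caveat of this form: in floating point, `pre + (post - pre)` can differ from `post` by a rounding error for arbitrary values, so "exact" identity holds up to that rounding.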
cortex/backtrack_head.py CHANGED
@@ -153,6 +153,11 @@ class BacktrackHead(CortexModule):
153
  def get_confidence_history(self) -> torch.Tensor:
154
  """Return the confidence scores across all layers from the last forward pass."""
155
  return self._confidence_history.clone()
 
 
 
 
 
156
 
157
  def was_triggered(self) -> bool:
158
  """Whether backtracking was triggered in the last forward pass."""
@@ -160,4 +165,4 @@ class BacktrackHead(CortexModule):
160
 
161
  def extra_repr(self):
162
  return (f"hidden_dim={self.hidden_dim}, drop_threshold={self.drop_threshold}, "
163
- f"{super().extra_repr()}")
 
153
  def get_confidence_history(self) -> torch.Tensor:
154
  """Return the confidence scores across all layers from the last forward pass."""
155
  return self._confidence_history.clone()
156
+
157
+ def reset_state(self, batch_size: int = 1):
158
+ """Clear confidence history between independent examples."""
159
+ self._confidence_history.zero_()
160
+ self._last_triggered = torch.tensor(False, device=self._last_triggered.device)
161
 
162
  def was_triggered(self) -> bool:
163
  """Whether backtracking was triggered in the last forward pass."""
 
165
 
166
  def extra_repr(self):
167
  return (f"hidden_dim={self.hidden_dim}, drop_threshold={self.drop_threshold}, "
168
+ f"{super().extra_repr()}")
cortex/core.py CHANGED
@@ -85,6 +85,10 @@ class CortexModule(ABC, nn.Module):
85
  for hook in self._hooks:
86
  hook.remove()
87
  self._hooks.clear()
 
 
 
 
88
 
89
  def enable(self):
90
  self._active = True
@@ -352,6 +356,7 @@ class CortexSurgeon:
352
  logger.info(f"Injected '{name}' into layers {target_layer_idxs}")
353
 
354
  self._operated = True
 
355
 
356
  total_params = sum(p.numel() for p in self.model.parameters())
357
  cortex_params = sum(
@@ -371,11 +376,12 @@ class CortexSurgeon:
371
  def post_ffn_hook(mod, inp, output, _module=module, _layer_idx=layer_idx):
372
  if not _module.is_active:
373
  return output
 
374
  if isinstance(output, tuple):
375
  hidden = output[0]
376
- hidden = _module(hidden, layer_idx=_layer_idx)
377
  return (hidden,) + output[1:]
378
- return _module(output, layer_idx=_layer_idx)
379
  return layer.register_forward_hook(post_ffn_hook)
380
 
381
  elif point == InjectionPoint.PRE_ATTENTION:
@@ -455,6 +461,8 @@ class CortexSurgeon:
455
 
456
  self.modules.clear()
457
  self._operated = False
 
 
458
  logger.info("All Cortex modules removed, model restored")
459
 
460
  def get_trainable_parameters(self):
@@ -490,4 +498,4 @@ class CortexSurgeon:
490
  for name, module in self.modules.items():
491
  if name in state:
492
  module.load_state_dict(state[name]["state_dict"])
493
- logger.info(f"Loaded weights for '{name}'")
 
85
  for hook in self._hooks:
86
  hook.remove()
87
  self._hooks.clear()
88
+
89
+ def reset_state(self, batch_size: int = 1):
90
+ """Reset per-example runtime state, if the module keeps any."""
91
+ pass
92
 
93
  def enable(self):
94
  self._active = True
 
356
  logger.info(f"Injected '{name}' into layers {target_layer_idxs}")
357
 
358
  self._operated = True
359
+ setattr(self.model, "_cortex_surgeon", self)
360
 
361
  total_params = sum(p.numel() for p in self.model.parameters())
362
  cortex_params = sum(
 
376
  def post_ffn_hook(mod, inp, output, _module=module, _layer_idx=layer_idx):
377
  if not _module.is_active:
378
  return output
379
+ pre_hidden = inp[0] if isinstance(inp, tuple) and len(inp) > 0 else None
380
  if isinstance(output, tuple):
381
  hidden = output[0]
382
+ hidden = _module(hidden, layer_idx=_layer_idx, pre_layer_hidden=pre_hidden)
383
  return (hidden,) + output[1:]
384
+ return _module(output, layer_idx=_layer_idx, pre_layer_hidden=pre_hidden)
385
  return layer.register_forward_hook(post_ffn_hook)
386
 
387
  elif point == InjectionPoint.PRE_ATTENTION:
 
461
 
462
  self.modules.clear()
463
  self._operated = False
464
+ if getattr(self.model, "_cortex_surgeon", None) is self:
465
+ delattr(self.model, "_cortex_surgeon")
466
  logger.info("All Cortex modules removed, model restored")
467
 
468
  def get_trainable_parameters(self):
 
498
  for name, module in self.modules.items():
499
  if name in state:
500
  module.load_state_dict(state[name]["state_dict"])
501
+ logger.info(f"Loaded weights for '{name}'")
cortex/hallucination_gate.py CHANGED
@@ -102,6 +102,10 @@ class HallucinationGate(CortexModule):
102
  # Learnable gate bias per dimension — allows the model to learn which
103
  # dimensions are safe to suppress vs which should always pass through
104
  self.dim_gate = nn.Parameter(torch.zeros(1, 1, hidden_dim))
 
 
 
 
105
 
106
  # Running confidence for monitoring/logging
107
  self.register_buffer("_last_confidence", torch.tensor(0.5), persistent=False)
@@ -111,6 +115,7 @@ class HallucinationGate(CortexModule):
111
  self,
112
  hidden_states: torch.Tensor,
113
  layer_idx: int,
 
114
  **kwargs
115
  ) -> torch.Tensor:
116
  """
@@ -131,13 +136,21 @@ class HallucinationGate(CortexModule):
131
  # Compute per-dimension gate
132
  dim_bias = torch.sigmoid(self.dim_gate) # [1, 1, D], in (0,1)
133
 
134
- # Effective gate: combines token-level confidence with dimension-level bias
 
 
135
  # High dim_bias = dimension always passes through
136
  # Low dim_bias = dimension is gated by confidence
137
- gate = 1.0 - self.suppression_strength * (1.0 - confidence) * (1.0 - dim_bias)
 
138
 
139
- # Apply gate
140
- gated_output = hidden_states * gate
 
 
 
 
 
141
 
142
  return gated_output
143
 
@@ -148,4 +161,4 @@ class HallucinationGate(CortexModule):
148
  def extra_repr(self):
149
  return (f"hidden_dim={self.hidden_dim}, "
150
  f"suppression_strength={self.suppression_strength}, "
151
- f"{super().extra_repr()}")
 
102
  # Learnable gate bias per dimension — allows the model to learn which
103
  # dimensions are safe to suppress vs which should always pass through
104
  self.dim_gate = nn.Parameter(torch.zeros(1, 1, hidden_dim))
105
+
106
+ # Start as an exact no-op. Once this scalar moves away from zero during
107
+ # training, the confidence probe and dimension gate control suppression.
108
+ self.suppression_scale = nn.Parameter(torch.tensor(0.0))
109
 
110
  # Running confidence for monitoring/logging
111
  self.register_buffer("_last_confidence", torch.tensor(0.5), persistent=False)
 
115
  self,
116
  hidden_states: torch.Tensor,
117
  layer_idx: int,
118
+ pre_layer_hidden: Optional[torch.Tensor] = None,
119
  **kwargs
120
  ) -> torch.Tensor:
121
  """
 
136
  # Compute per-dimension gate
137
  dim_bias = torch.sigmoid(self.dim_gate) # [1, 1, D], in (0,1)
138
 
139
+ # Effective gate: combines token-level confidence with dimension-level bias.
140
+ # suppression_scale is zero at initialization, so this module is exactly
141
+ # transparent before Cortex-specific training.
142
  # High dim_bias = dimension always passes through
143
  # Low dim_bias = dimension is gated by confidence
144
+ effective_suppression = self.suppression_strength * self.suppression_scale
145
+ gate = 1.0 - effective_suppression * (1.0 - confidence) * (1.0 - dim_bias)
146
 
147
+ # Apply the gate to the layer update when the hook can provide the block
148
+ # input. Fallback to gating the stream itself for manually-called modules.
149
+ if pre_layer_hidden is not None and pre_layer_hidden.shape == hidden_states.shape:
150
+ residual_update = hidden_states - pre_layer_hidden
151
+ gated_output = pre_layer_hidden + gate * residual_update
152
+ else:
153
+ gated_output = hidden_states * gate
154
 
155
  return gated_output
156
 
 
161
  def extra_repr(self):
162
  return (f"hidden_dim={self.hidden_dim}, "
163
  f"suppression_strength={self.suppression_strength}, "
164
+ f"{super().extra_repr()}")
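The effect of the new `suppression_scale` parameter follows directly from the gate formula in the hunk. The sketch below evaluates that formula with scalar floats; the function name and the numeric inputs are illustrative assumptions, and the real module applies the same expression to tensors.

```python
def hallucination_gate(
    confidence: float,
    dim_bias: float,
    suppression_strength: float,
    suppression_scale: float,
) -> float:
    """gate = 1 - strength * scale * (1 - confidence) * (1 - dim_bias)."""
    effective_suppression = suppression_strength * suppression_scale
    return 1.0 - effective_suppression * (1.0 - confidence) * (1.0 - dim_bias)


# dim_gate starts at zero, so dim_bias = sigmoid(0) = 0.5 at initialization.
at_init = hallucination_gate(
    confidence=0.2, dim_bias=0.5, suppression_strength=0.5, suppression_scale=0.0
)
# Once the scale has moved to 1, low confidence suppresses the update.
trained_low_conf = hallucination_gate(
    confidence=0.2, dim_bias=0.5, suppression_strength=0.5, suppression_scale=1.0
)
# Full confidence leaves the update untouched regardless of the scale.
trained_high_conf = hallucination_gate(
    confidence=1.0, dim_bias=0.5, suppression_strength=0.5, suppression_scale=1.0
)
```

With `suppression_scale` at zero the gate is exactly 1, so both the residual-update path and the fallback `hidden_states * gate` path leave the hidden states unchanged up to floating-point rounding.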
cortex/memory_bank.py CHANGED
@@ -96,6 +96,10 @@ class MemoryBank(CortexModule):
96
  def reset_memory(self, batch_size: int = 1):
97
  """Reset memory to initial state."""
98
  self._memory_state = self.memory_init.expand(batch_size, -1, -1).clone()
 
 
 
 
99
 
100
  def forward(
101
  self,
@@ -167,4 +171,4 @@ class MemoryBank(CortexModule):
167
 
168
  def extra_repr(self):
169
  return (f"hidden_dim={self.hidden_dim}, num_slots={self.num_slots}, "
170
- f"num_heads={self.num_heads}, {super().extra_repr()}")
 
96
  def reset_memory(self, batch_size: int = 1):
97
  """Reset memory to initial state."""
98
  self._memory_state = self.memory_init.expand(batch_size, -1, -1).clone()
99
+
100
+ def reset_state(self, batch_size: int = 1):
101
+ """Reset memory between independent benchmark examples."""
102
+ self.reset_memory(batch_size=batch_size)
103
 
104
  def forward(
105
  self,
 
171
 
172
  def extra_repr(self):
173
  return (f"hidden_dim={self.hidden_dim}, num_slots={self.num_slots}, "
174
+ f"num_heads={self.num_heads}, {super().extra_repr()}")
test_cortex.py CHANGED
@@ -4,7 +4,7 @@ Verify that:
4
  1. All modules inject without errors
5
  2. Forward pass works
6
  3. Gradients flow only through Cortex parameters
7
- 4. Output changes when modules are enabled/disabled
8
  5. Each module's specific functionality works
9
 
10
  Usage:
@@ -114,4 +114,4 @@ def main():
114
  print(f"{'='*60}")
115
 
116
  if __name__ == "__main__":
117
- main()
 
4
  1. All modules inject without errors
5
  2. Forward pass works
6
  3. Gradients flow only through Cortex parameters
7
+ 4. Freshly injected modules preserve base outputs
8
  5. Each module's specific functionality works
9
 
10
  Usage:
 
114
  print(f"{'='*60}")
115
 
116
  if __name__ == "__main__":
117
+ main()