# cortex/test_cortex.py
"""
Comprehensive test: Load a real pretrained LLM and inject all Cortex modules.
Verify that:
1. All modules inject without errors
2. Forward pass works
3. Gradients flow only through Cortex parameters
4. Freshly injected modules preserve base outputs
5. Each module's specific functionality works
Usage:
python test_cortex.py
"""
import torch
import sys
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from cortex.torch_device import resolve_torch_device
from cortex import (
    CortexSurgeon,
    MemoryBank,
    HallucinationGate,
    PauseAndThink,
    BacktrackHead,
    SteeringVector,
    AdaptiveDepth,
)
import logging
logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")
def main():
    device = resolve_torch_device("auto")
    print(f"\n{'='*60}")
    print(f"CORTEX TEST — Device: {device}")
    print(f"{'='*60}\n")

    model_name = "HuggingFaceTB/SmolLM2-135M"
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float32, device_map=device)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    hidden_dim = model.config.hidden_size
    num_layers = model.config.num_hidden_layers
    print(f"Model: hidden_dim={hidden_dim}, num_layers={num_layers}")

    # Baseline
    test_input = "The capital of France is"
    inputs = tokenizer(test_input, return_tensors="pt").to(device)
    with torch.no_grad():
        baseline_logits = model(**inputs).logits.clone()
    print(f"Baseline next token: '{tokenizer.decode(baseline_logits[0, -1].argmax())}'")

    # Surgery
    surgeon = CortexSurgeon(model)
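    # The head count passed to MemoryBank must divide hidden_dim evenly; back off from 4 until it does.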
    num_heads = 4
    while hidden_dim % num_heads != 0:
        num_heads -= 1
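    # Split the stack into thirds: middle layers host memory, pause-and-think, and steering; deep layers host the hallucination gate.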
    middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
    deep_layers = list(range(2 * num_layers // 3, num_layers))

    surgeon.add_module("memory", MemoryBank(hidden_dim=hidden_dim, num_slots=16, num_heads=num_heads, target_layers=middle_layers))
    surgeon.add_module("halluc_gate", HallucinationGate(hidden_dim=hidden_dim, bottleneck_dim=32, target_layers=deep_layers))
    surgeon.add_module("pause_think", PauseAndThink(hidden_dim=hidden_dim, num_think_tokens=4, target_layers=middle_layers))
    surgeon.add_module("backtrack", BacktrackHead(hidden_dim=hidden_dim, confidence_bottleneck=32, num_layers=num_layers, target_layers="all"))
    surgeon.add_module("steering", SteeringVector(hidden_dim=hidden_dim, num_directions=2, direction_names=["truthfulness", "helpfulness"], target_layers=middle_layers))
    surgeon.add_module("adaptive_depth", AdaptiveDepth(hidden_dim=hidden_dim, target_layers="all"))
    surgeon.operate(freeze_base=True)

    report = surgeon.get_parameter_report()
    total_cortex = sum(info['trainable'] for info in report.values())
    total_model = sum(p.numel() for p in model.parameters())
    print(f"\nCortex: {total_cortex:,} params ({total_cortex/total_model*100:.2f}% overhead)")

    # Forward with modules
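    # Freshly injected modules should preserve the base outputs (docstring point 4), so this diff should stay small.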
    with torch.no_grad():
        enhanced_logits = model(**inputs).logits.clone()
    logit_diff = (enhanced_logits - baseline_logits).abs().mean().item()
    print(f"Logit diff with modules: {logit_diff:.6f}")

    # Gradient check
    output = model(**inputs, labels=inputs["input_ids"].clone())
    output.loss.backward()
    cortex_grads = sum(1 for p in surgeon.get_trainable_parameters() if p.grad is not None and p.grad.abs().sum() > 0)
    print(f"Cortex params with gradients: {cortex_grads}")
    # Enable/disable
    for mod in surgeon.modules.values():
        mod.disable()
    with torch.no_grad():
        disabled_diff = (model(**inputs).logits - baseline_logits).abs().mean().item()
    print(f"Disabled vs baseline diff: {disabled_diff:.8f} (should be ~0)")
    for mod in surgeon.modules.values():
        mod.enable()

    # Generation
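    # Sampling a short continuation verifies the injected modules do not break autoregressive decoding.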
    gen_input = tokenizer("Once upon a time", return_tensors="pt").to(device)
    with torch.no_grad():
        gen_output = model.generate(**gen_input, max_new_tokens=50, do_sample=True, temperature=0.8, pad_token_id=tokenizer.pad_token_id)
    print(f"Generated: {tokenizer.decode(gen_output[0], skip_special_tokens=True)[:200]}")

    # Save
    surgeon.save_cortex_modules("/tmp/cortex_modules.pt")
    print(f"Saved: {os.path.getsize('/tmp/cortex_modules.pt') / 1024:.1f} KB")

    print(f"\n{'='*60}")
    print("ALL TESTS PASSED ✓")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()