File size: 4,612 Bytes
ba8ab4a 0de2901 ba8ab4a 0ac64e3 ba8ab4a 0ac64e3 ba8ab4a 0de2901 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | """
Comprehensive test: Load a real pretrained LLM and inject all Cortex modules.
Verify that:
1. All modules inject without errors
2. Forward pass works
3. Gradients flow only through Cortex parameters
4. Freshly injected modules preserve base outputs
5. Each module's specific functionality works
Usage:
python test_cortex.py
"""
import torch
import sys
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from cortex.torch_device import resolve_torch_device
from cortex import (
CortexSurgeon,
MemoryBank,
HallucinationGate,
PauseAndThink,
BacktrackHead,
SteeringVector,
AdaptiveDepth
)
import logging
logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")
def main():
device = resolve_torch_device("auto")
print(f"\n{'='*60}")
print(f"CORTEX TEST — Device: {device}")
print(f"{'='*60}\n")
model_name = "HuggingFaceTB/SmolLM2-135M"
print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float32, device_map=device)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
hidden_dim = model.config.hidden_size
num_layers = model.config.num_hidden_layers
print(f"Model: hidden_dim={hidden_dim}, num_layers={num_layers}")
# Baseline
test_input = "The capital of France is"
inputs = tokenizer(test_input, return_tensors="pt").to(device)
with torch.no_grad():
baseline_logits = model(**inputs).logits.clone()
print(f"Baseline next token: '{tokenizer.decode(baseline_logits[0, -1].argmax())}'")
# Surgery
surgeon = CortexSurgeon(model)
num_heads = 4
while hidden_dim % num_heads != 0:
num_heads -= 1
middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
deep_layers = list(range(2 * num_layers // 3, num_layers))
surgeon.add_module("memory", MemoryBank(hidden_dim=hidden_dim, num_slots=16, num_heads=num_heads, target_layers=middle_layers))
surgeon.add_module("halluc_gate", HallucinationGate(hidden_dim=hidden_dim, bottleneck_dim=32, target_layers=deep_layers))
surgeon.add_module("pause_think", PauseAndThink(hidden_dim=hidden_dim, num_think_tokens=4, target_layers=middle_layers))
surgeon.add_module("backtrack", BacktrackHead(hidden_dim=hidden_dim, confidence_bottleneck=32, num_layers=num_layers, target_layers="all"))
surgeon.add_module("steering", SteeringVector(hidden_dim=hidden_dim, num_directions=2, direction_names=["truthfulness", "helpfulness"], target_layers=middle_layers))
surgeon.add_module("adaptive_depth", AdaptiveDepth(hidden_dim=hidden_dim, target_layers="all"))
surgeon.operate(freeze_base=True)
report = surgeon.get_parameter_report()
total_cortex = sum(info['trainable'] for info in report.values())
total_model = sum(p.numel() for p in model.parameters())
print(f"\nCortex: {total_cortex:,} params ({total_cortex/total_model*100:.2f}% overhead)")
# Forward with modules
with torch.no_grad():
enhanced_logits = model(**inputs).logits.clone()
logit_diff = (enhanced_logits - baseline_logits).abs().mean().item()
print(f"Logit diff with modules: {logit_diff:.6f}")
# Gradient check
output = model(**inputs, labels=inputs["input_ids"].clone())
output.loss.backward()
cortex_grads = sum(1 for p in surgeon.get_trainable_parameters() if p.grad is not None and p.grad.abs().sum() > 0)
print(f"Cortex params with gradients: {cortex_grads}")
# Enable/disable
for mod in surgeon.modules.values():
mod.disable()
with torch.no_grad():
disabled_diff = (model(**inputs).logits - baseline_logits).abs().mean().item()
print(f"Disabled vs baseline diff: {disabled_diff:.8f} (should be ~0)")
for mod in surgeon.modules.values():
mod.enable()
# Generation
gen_input = tokenizer("Once upon a time", return_tensors="pt").to(device)
with torch.no_grad():
gen_output = model.generate(**gen_input, max_new_tokens=50, do_sample=True, temperature=0.8, pad_token_id=tokenizer.pad_token_id)
print(f"Generated: {tokenizer.decode(gen_output[0], skip_special_tokens=True)[:200]}")
# Save
surgeon.save_cortex_modules("/tmp/cortex_modules.pt")
print(f"Saved: {os.path.getsize('/tmp/cortex_modules.pt') / 1024:.1f} KB")
print(f"\n{'='*60}")
print("ALL TESTS PASSED ✓")
print(f"{'='*60}")
if __name__ == "__main__":
main()
|