"""
Comprehensive test: Load a real pretrained LLM and inject all Cortex modules.

Verify that:
1. All modules inject without errors
2. Forward pass works
3. Gradients flow only through Cortex parameters
4. Freshly injected modules preserve base outputs
5. Each module's specific functionality works

Checks 3 and the "disabled modules reproduce the baseline" check are enforced:
the script raises ``AssertionError`` (non-zero exit status) when they fail
instead of unconditionally reporting success. Check 4 is reported as a
non-fatal warning because no tolerance is specified for it.

Usage:
    python test_cortex.py
"""

import logging
import os
import sys
import tempfile

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from cortex.torch_device import resolve_torch_device
from cortex import (
    CortexSurgeon,
    MemoryBank,
    HallucinationGate,
    PauseAndThink,
    BacktrackHead,
    SteeringVector,
    AdaptiveDepth,
)

logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")

# Largest mean absolute logit difference still treated as "no change".
LOGIT_TOLERANCE = 1e-4


def _require(condition, message):
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    An explicit raise is used instead of the ``assert`` statement so the
    checks still run under ``python -O``.
    """
    if not condition:
        raise AssertionError(message)


def _largest_divisor_at_most(value, upper):
    """Return the largest integer in ``[1, upper]`` that divides *value*."""
    candidate = upper
    while value % candidate != 0:
        candidate -= 1
    return candidate


def main():
    """Inject every Cortex module into SmolLM2-135M and run the smoke checks.

    Raises:
        AssertionError: if no Cortex parameter receives a gradient, if a
            parameter outside the Cortex trainable set receives one, or if
            disabling all modules does not reproduce the baseline logits.
    """
    device = resolve_torch_device("auto")
    print(f"\n{'='*60}")
    print(f"CORTEX TEST — Device: {device}")
    print(f"{'='*60}\n")

    model_name = "HuggingFaceTB/SmolLM2-135M"
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, dtype=torch.float32, device_map=device
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    hidden_dim = model.config.hidden_size
    num_layers = model.config.num_hidden_layers
    print(f"Model: hidden_dim={hidden_dim}, num_layers={num_layers}")

    # Baseline: logits of the untouched model, used as the reference below.
    test_input = "The capital of France is"
    inputs = tokenizer(test_input, return_tensors="pt").to(device)
    with torch.no_grad():
        baseline_logits = model(**inputs).logits.clone()
    print(f"Baseline next token: '{tokenizer.decode(baseline_logits[0, -1].argmax())}'")

    # Surgery
    surgeon = CortexSurgeon(model)
    # Start from 4 heads and step down until the count divides hidden_dim.
    num_heads = _largest_divisor_at_most(hidden_dim, 4)
    middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
    deep_layers = list(range(2 * num_layers // 3, num_layers))

    surgeon.add_module(
        "memory",
        MemoryBank(
            hidden_dim=hidden_dim,
            num_slots=16,
            num_heads=num_heads,
            target_layers=middle_layers,
        ),
    )
    surgeon.add_module(
        "halluc_gate",
        HallucinationGate(
            hidden_dim=hidden_dim, bottleneck_dim=32, target_layers=deep_layers
        ),
    )
    surgeon.add_module(
        "pause_think",
        PauseAndThink(
            hidden_dim=hidden_dim, num_think_tokens=4, target_layers=middle_layers
        ),
    )
    surgeon.add_module(
        "backtrack",
        BacktrackHead(
            hidden_dim=hidden_dim,
            confidence_bottleneck=32,
            num_layers=num_layers,
            target_layers="all",
        ),
    )
    surgeon.add_module(
        "steering",
        SteeringVector(
            hidden_dim=hidden_dim,
            num_directions=2,
            direction_names=["truthfulness", "helpfulness"],
            target_layers=middle_layers,
        ),
    )
    surgeon.add_module(
        "adaptive_depth",
        AdaptiveDepth(hidden_dim=hidden_dim, target_layers="all"),
    )
    surgeon.operate(freeze_base=True)

    report = surgeon.get_parameter_report()
    total_cortex = sum(info['trainable'] for info in report.values())
    total_model = sum(p.numel() for p in model.parameters())
    print(f"\nCortex: {total_cortex:,} params ({total_cortex/total_model*100:.2f}% overhead)")

    # Forward with modules
    with torch.no_grad():
        enhanced_logits = model(**inputs).logits.clone()
    logit_diff = (enhanced_logits - baseline_logits).abs().mean().item()
    print(f"Logit diff with modules: {logit_diff:.6f}")
    if logit_diff > LOGIT_TOLERANCE:
        # Non-fatal: the module docstring expects fresh modules to preserve
        # base outputs, but no tolerance is specified for that expectation.
        print(
            "WARNING: freshly injected modules changed base outputs "
            f"(mean abs diff {logit_diff:.6f} > {LOGIT_TOLERANCE})"
        )

    # Gradient check: snapshot the trainable set once so the same parameter
    # objects are used for both the count and the leak check.
    trainable = list(surgeon.get_trainable_parameters())
    output = model(**inputs, labels=inputs["input_ids"].clone())
    output.loss.backward()
    cortex_grads = sum(
        1 for p in trainable if p.grad is not None and p.grad.abs().sum() > 0
    )
    print(f"Cortex params with gradients: {cortex_grads}")
    _require(cortex_grads > 0, "No Cortex parameter received a gradient")

    # Any model parameter outside the trainable set that ends up with a
    # non-zero gradient means the base model was not actually frozen.
    trainable_ids = {id(p) for p in trainable}
    leaked = [
        name
        for name, p in model.named_parameters()
        if id(p) not in trainable_ids
        and p.grad is not None
        and p.grad.abs().sum() > 0
    ]
    _require(
        not leaked,
        f"Gradients leaked into {len(leaked)} non-Cortex parameter(s), "
        f"e.g. {leaked[:3]}",
    )

    # Enable/disable: with every module disabled the model must reproduce
    # the baseline logits; modules are re-enabled even if the check fails.
    for mod in surgeon.modules.values():
        mod.disable()
    try:
        with torch.no_grad():
            disabled_diff = (model(**inputs).logits - baseline_logits).abs().mean().item()
        print(f"Disabled vs baseline diff: {disabled_diff:.8f} (should be ~0)")
        _require(
            disabled_diff <= LOGIT_TOLERANCE,
            f"Disabled modules still alter outputs (mean abs diff {disabled_diff:.8f})",
        )
    finally:
        for mod in surgeon.modules.values():
            mod.enable()

    # Generation
    gen_input = tokenizer("Once upon a time", return_tensors="pt").to(device)
    with torch.no_grad():
        gen_output = model.generate(
            **gen_input,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.8,
            pad_token_id=tokenizer.pad_token_id,
        )
    print(f"Generated: {tokenizer.decode(gen_output[0], skip_special_tokens=True)[:200]}")

    # Save: use the platform temp directory rather than a hard-coded /tmp.
    save_path = os.path.join(tempfile.gettempdir(), "cortex_modules.pt")
    surgeon.save_cortex_modules(save_path)
    print(f"Saved: {os.path.getsize(save_path) / 1024:.1f} KB")

    print(f"\n{'='*60}")
    print("ALL TESTS PASSED ✓")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()