File size: 4,612 Bytes
ba8ab4a
 
 
 
 
 
0de2901
ba8ab4a
 
 
 
 
 
 
 
 
 
 
0ac64e3
ba8ab4a
 
 
 
 
 
 
 
 
 
 
 
 
0ac64e3
ba8ab4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0de2901
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Comprehensive test: Load a real pretrained LLM and inject all Cortex modules.
Verify that:
1. All modules inject without errors
2. Forward pass works
3. Gradients flow only through Cortex parameters
4. Freshly injected modules preserve base outputs
5. Each module's specific functionality works

Usage:
    python test_cortex.py
"""

import logging
import os
import sys
import tempfile

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from cortex import (
    CortexSurgeon,
    MemoryBank,
    HallucinationGate,
    PauseAndThink,
    BacktrackHead,
    SteeringVector,
    AdaptiveDepth,
)
from cortex.torch_device import resolve_torch_device
logging.basicConfig(level=logging.INFO, format="%(name)s | %(message)s")

def main(
    *,
    preserve_tol: float = 1e-2,
    disabled_tol: float = 1e-5,
    save_path: "str | None" = None,
) -> None:
    """Load a small pretrained causal LM, inject every Cortex module, and verify it.

    Performs the checks listed in the module docstring and prints a short report.
    Failed checks are collected instead of aborting on the first one, so a single
    run reports every problem. If any check fails, a ``RuntimeError`` listing them
    is raised (giving a non-zero exit code when run as a script) and the success
    banner is not printed.

    Args:
        preserve_tol: Maximum mean absolute logit difference allowed between the
            base model and the model with freshly injected (untrained) modules.
        disabled_tol: Maximum mean absolute logit difference allowed between the
            base model and the model with every Cortex module disabled.
        save_path: Where to write the saved Cortex modules. Defaults to
            ``cortex_modules.pt`` inside the platform temp directory.

    Raises:
        RuntimeError: If any verification check fails.
    """
    failures = []

    def check(condition, message):
        """Print a PASS/FAIL line for *message* and record it if *condition* is false."""
        print(f"  [{'PASS' if condition else 'FAIL'}] {message}")
        if not condition:
            failures.append(message)

    device = resolve_torch_device("auto")
    print(f"\n{'='*60}")
    print(f"CORTEX TEST — Device: {device}")
    print(f"{'='*60}\n")

    model_name = "HuggingFaceTB/SmolLM2-135M"
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float32, device_map=device)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    hidden_dim = model.config.hidden_size
    num_layers = model.config.num_hidden_layers
    print(f"Model: hidden_dim={hidden_dim}, num_layers={num_layers}")

    # Baseline: logits of the untouched model, the reference for every comparison below.
    test_input = "The capital of France is"
    inputs = tokenizer(test_input, return_tensors="pt").to(device)
    with torch.no_grad():
        baseline_logits = model(**inputs).logits.clone()
    baseline_next = baseline_logits[0, -1].argmax().item()
    print(f"Baseline next token: '{tokenizer.decode(baseline_next)}'")

    # Surgery. Use the largest head count <= 4 that divides hidden_dim
    # (the loop always terminates at 1 at the latest).
    surgeon = CortexSurgeon(model)
    num_heads = 4
    while hidden_dim % num_heads != 0:
        num_heads -= 1

    middle_layers = list(range(num_layers // 3, 2 * num_layers // 3))
    deep_layers = list(range(2 * num_layers // 3, num_layers))

    surgeon.add_module("memory", MemoryBank(hidden_dim=hidden_dim, num_slots=16, num_heads=num_heads, target_layers=middle_layers))
    surgeon.add_module("halluc_gate", HallucinationGate(hidden_dim=hidden_dim, bottleneck_dim=32, target_layers=deep_layers))
    surgeon.add_module("pause_think", PauseAndThink(hidden_dim=hidden_dim, num_think_tokens=4, target_layers=middle_layers))
    surgeon.add_module("backtrack", BacktrackHead(hidden_dim=hidden_dim, confidence_bottleneck=32, num_layers=num_layers, target_layers="all"))
    surgeon.add_module("steering", SteeringVector(hidden_dim=hidden_dim, num_directions=2, direction_names=["truthfulness", "helpfulness"], target_layers=middle_layers))
    surgeon.add_module("adaptive_depth", AdaptiveDepth(hidden_dim=hidden_dim, target_layers="all"))

    surgeon.operate(freeze_base=True)
    print("All Cortex modules injected")

    report = surgeon.get_parameter_report()
    total_cortex = sum(info['trainable'] for info in report.values())
    total_model = sum(p.numel() for p in model.parameters())
    print(f"\nCortex: {total_cortex:,} params ({total_cortex/total_model*100:.2f}% overhead)")

    # Forward pass with freshly injected (untrained) modules.
    with torch.no_grad():
        enhanced_logits = model(**inputs).logits.clone()
    logit_diff = (enhanced_logits - baseline_logits).abs().mean().item()
    print(f"Logit diff with modules: {logit_diff:.6f}")
    check(bool(torch.isfinite(enhanced_logits).all().item()), "forward pass with modules yields finite logits")
    check(
        enhanced_logits[0, -1].argmax().item() == baseline_next,
        "fresh modules keep the baseline next-token prediction",
    )
    check(logit_diff <= preserve_tol, f"fresh modules preserve base logits (diff {logit_diff:.6f} <= {preserve_tol})")

    # Gradient check: gradients must reach Cortex parameters and nothing else.
    # Materialize once so the same parameter set is used for every check below.
    cortex_params = list(surgeon.get_trainable_parameters())
    cortex_ids = {id(p) for p in cortex_params}
    output = model(**inputs, labels=inputs["input_ids"].clone())
    output.loss.backward()
    cortex_grads = sum(1 for p in cortex_params if p.grad is not None and p.grad.abs().sum() > 0)
    print(f"Cortex params with gradients: {cortex_grads}")
    base_with_grads = [
        name for name, p in model.named_parameters()
        if id(p) not in cortex_ids and p.grad is not None
    ]
    check(cortex_grads > 0, "gradients reach Cortex parameters")
    check(not base_with_grads, f"frozen base parameters receive no gradients ({len(base_with_grads)} do)")
    # Release gradient buffers; cortex params may not be registered under `model`.
    model.zero_grad(set_to_none=True)
    for p in cortex_params:
        p.grad = None

    # Enable/disable: with every module disabled the model must match the baseline.
    for mod in surgeon.modules.values():
        mod.disable()
    try:
        with torch.no_grad():
            disabled_diff = (model(**inputs).logits - baseline_logits).abs().mean().item()
    finally:
        # Always re-enable, even if the disabled forward pass raised.
        for mod in surgeon.modules.values():
            mod.enable()
    print(f"Disabled vs baseline diff: {disabled_diff:.8f} (should be ~0)")
    check(disabled_diff <= disabled_tol, f"disabled modules reproduce base logits (diff {disabled_diff:.8f} <= {disabled_tol})")

    # Generation (seeded so the sampled continuation is reproducible).
    gen_input = tokenizer("Once upon a time", return_tensors="pt").to(device)
    torch.manual_seed(0)
    with torch.no_grad():
        gen_output = model.generate(**gen_input, max_new_tokens=50, do_sample=True, temperature=0.8, pad_token_id=tokenizer.pad_token_id)
    print(f"Generated: {tokenizer.decode(gen_output[0], skip_special_tokens=True)[:200]}")
    check(gen_output.shape[-1] > gen_input["input_ids"].shape[-1], "generation produces new tokens")

    # Save to a portable location (the platform temp dir unless overridden).
    if save_path is None:
        save_path = os.path.join(tempfile.gettempdir(), "cortex_modules.pt")
    surgeon.save_cortex_modules(save_path)
    size_bytes = os.path.getsize(save_path)
    print(f"Saved: {size_bytes / 1024:.1f} KB")
    check(size_bytes > 0, f"Cortex modules saved to {save_path}")

    print(f"\n{'='*60}")
    if failures:
        print(f"{len(failures)} CHECK(S) FAILED ✗")
        print(f"{'='*60}")
        raise RuntimeError("Cortex test failed:\n" + "\n".join(f"- {msg}" for msg in failures))
    print("ALL TESTS PASSED ✓")
    print(f"{'='*60}")

if __name__ == "__main__":
    # A failed check raises RuntimeError, which exits the interpreter non-zero.
    main()