🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f
verified
#!/usr/bin/env python3
"""
═══════════════════════════════════════════════════════════════════════════════
COGNITIVE ENHANCEMENT SUITE v1.0
Making 8B Think Like 100B Through Hidden State Analysis
═══════════════════════════════════════════════════════════════════════════════
CORE INSIGHT:
Small models often HAVE capability they don't USE consistently.
By detecting when the model is about to underperform and intervening,
we can recover performance closer to larger models.
ENHANCEMENT PROBES:
1. DEPTH PROBE       - Detect shallow reasoning → Force chain-of-thought
2. SPECIFICITY PROBE - Detect vague answers → Penalize generic words
3. CALIBRATION PROBE - Detect overconfidence → Inject uncertainty
4. FOCUS PROBE       - Detect topic drift → Steer back on topic
5. COHERENCE PROBE   - Detect incoherence → Maintain logical flow
AUTHOR: Logan Matthew Napolitano
LICENSE: CC BY 4.0
STATUS: Research / Patent Pending
═══════════════════════════════════════════════════════════════════════════════
"""
| import os | |
| import sys | |
| import json | |
| import time | |
| import random | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass, field, asdict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
@dataclass
class EnhancementConfig:
    """Configuration for cognitive enhancement probes.

    Bundles the frozen LM geometry the probes read (hidden width, tapped
    layers), the probe architecture sizes, and training hyper-parameters.

    NOTE: the ``@dataclass`` decorator is required here — the class uses
    ``field(default_factory=...)`` and is constructed with keyword
    arguments, both of which only work on a dataclass.
    """
    hidden_dim: int = 4096        # width of the LM hidden states fed to probes
    fiber_dim: int = 16           # low-dimensional fiber space per probe
    head_hidden_dim: int = 64     # hidden width of the probe classifier MLP
    # Transformer layer indices whose hidden states the probes consume.
    probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24])
    learning_rate: float = 5e-5
    batch_size: int = 4
    gradient_accumulation: int = 4
    max_steps: int = 15000
    save_every: int = 1000
    output_dir: str = "cognitive_enhancement_output"
# ═══════════════════════════════════════════════════════════════════════════════
# PROBE DEFINITIONS
# ═══════════════════════════════════════════════════════════════════════════════
@dataclass
class ProbeDefinition:
    """Definition of a cognitive enhancement probe.

    Describes when a probe should intervene (``threshold`` on sigmoid risk)
    and how (boosting/suppressing the first token of each listed phrase by
    ``intervention_strength`` logits).

    NOTE: ``@dataclass`` is required — the registry below constructs these
    with keyword arguments and the class relies on ``field(default_factory)``.
    """
    name: str
    description: str
    intervention_type: str  # "suppress", "boost", or "steer"
    boost_tokens: List[str] = field(default_factory=list)
    suppress_tokens: List[str] = field(default_factory=list)
    threshold: float = 0.5              # risk above this triggers intervention
    intervention_strength: float = 3.0  # logit delta applied per token
# Registry of the five cognitive probes. Each entry pairs a detection
# threshold (sigmoid risk above which the probe fires) with the token
# phrases whose first tokens are boosted or suppressed in the logits.
ENHANCEMENT_PROBES = {
    "depth": ProbeDefinition(
        name="depth",
        description="Detect shallow reasoning, force chain-of-thought",
        intervention_type="boost",
        # Discourse markers that open explicit reasoning steps.
        boost_tokens=[
            "First", "First,", "Because", "Since", "Therefore",
            "This means", "The reason", "Step", "Let me", "To understand",
            "Consider", "Notice", "Given", "If we", "We can",
            "Thus", "Hence", "Consequently", "As a result",
        ],
        # Dismissive hedges that short-circuit reasoning.
        suppress_tokens=["Simply", "Just", "Obviously", "Clearly"],
        threshold=0.6,
        intervention_strength=3.0,
    ),
    "specificity": ProbeDefinition(
        name="specificity",
        description="Detect vague answers, penalize generic language",
        intervention_type="suppress",
        boost_tokens=["specifically", "exactly", "precisely", "namely", "for example"],
        # Generic filler words characteristic of vague answers.
        suppress_tokens=[
            "things", "stuff", "something", "somehow", "somewhat",
            "various", "many", "some", "often", "usually",
            "generally", "typically", "probably", "maybe", "perhaps",
            "kind of", "sort of", "basically", "essentially",
        ],
        threshold=0.5,
        intervention_strength=3.5,
    ),
    "calibration": ProbeDefinition(
        name="calibration",
        description="Detect overconfidence, inject appropriate uncertainty",
        intervention_type="boost",
        # Hedging language encouraged when the model is overconfident.
        boost_tokens=[
            "might", "may", "could", "possibly", "perhaps",
            "likely", "probably", "I think", "I believe",
            "it seems", "appears", "suggests", "indicates",
        ],
        # Absolute claims discouraged under overconfidence.
        suppress_tokens=[
            "definitely", "certainly", "absolutely", "always",
            "never", "impossible", "guaranteed", "undoubtedly",
        ],
        threshold=0.65,
        intervention_strength=2.5,
    ),
    "focus": ProbeDefinition(
        name="focus",
        description="Detect topic drift, steer back to the question",
        intervention_type="steer",
        boost_tokens=["regarding", "concerning", "about", "specifically", "to answer"],
        # Tangent-opening phrases that signal drift.
        suppress_tokens=["by the way", "incidentally", "speaking of", "reminds me"],
        threshold=0.55,
        intervention_strength=3.0,
    ),
    "coherence": ProbeDefinition(
        name="coherence",
        description="Detect logical incoherence, maintain flow",
        intervention_type="steer",
        # Logical connectives that restore discourse flow.
        boost_tokens=[
            "therefore", "thus", "so", "hence", "consequently",
            "however", "but", "although", "furthermore", "moreover",
        ],
        suppress_tokens=[],
        threshold=0.6,
        intervention_strength=2.5,
    ),
}
# ═══════════════════════════════════════════════════════════════════════════════
# NEURAL NETWORK ARCHITECTURE
# ═══════════════════════════════════════════════════════════════════════════════
class EnhancementFiberProjection(nn.Module):
    """Fiber projection for cognitive enhancement probes.

    Projects hidden states from several transformer layers into a shared
    low-dimensional fiber space and blends them with learned, softmax-
    normalized per-layer weights.
    """

    def __init__(self, hidden_dim: int = 4096, fiber_dim: int = 16, n_layers: int = 3):
        super().__init__()
        # One bias-free linear map per tapped layer.
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)
        )
        # Uniform at init; normalized via softmax at forward time.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

    def forward(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Blend per-layer fiber projections into a single fiber tensor."""
        blend = F.softmax(self.layer_weights, dim=0)
        combined = None
        for weight, proj, states in zip(blend, self.projections, hidden_states_list):
            piece = weight * proj(states.float())
            combined = piece if combined is None else combined + piece
        return combined
class EnhancementHead(nn.Module):
    """Classification head for enhancement probe.

    A small 3-layer MLP mapping a fiber vector to a single risk logit.
    """

    def __init__(self, fiber_dim: int = 16, hidden_dim: int = 64):
        super().__init__()
        layers = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        # The attribute name "classifier" is load-bearing: checkpoints
        # address these weights via "classifier.<idx>" state-dict keys.
        self.classifier = nn.Sequential(*layers)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """Return one logit per input, with the trailing dim squeezed away."""
        logits = self.classifier(fiber)
        return logits.squeeze(-1)
class EnhancementProbe(nn.Module):
    """A single cognitive probe: fiber projection followed by a risk head."""

    def __init__(self, config: EnhancementConfig, probe_def: ProbeDefinition):
        super().__init__()
        self.config = config
        self.probe_def = probe_def
        self.fiber_projection = EnhancementFiberProjection(
            config.hidden_dim,
            config.fiber_dim,
            len(config.probe_layers),
        )
        self.head = EnhancementHead(config.fiber_dim, config.head_hidden_dim)
        # Training metadata, populated when a checkpoint is loaded.
        self.separation = 0.0
        self.trained_steps = 0

    def forward(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Return raw (pre-sigmoid) risk logits for the given layer states."""
        return self.head(self.fiber_projection(hidden_states_list))

    def predict_risk(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Return risk probabilities in [0, 1]."""
        return torch.sigmoid(self.forward(hidden_states_list))
class CognitiveEnhancementSuite(nn.Module):
    """Complete suite of cognitive enhancement probes.

    Holds one EnhancementProbe per entry in ENHANCEMENT_PROBES and tracks
    which probes have trained weights loaded; only loaded probes are
    consulted by get_all_risks().
    """

    def __init__(self, config: Optional[EnhancementConfig] = None):
        super().__init__()
        self.config = config or EnhancementConfig()
        self.probes = nn.ModuleDict({
            name: EnhancementProbe(self.config, probe_def)
            for name, probe_def in ENHANCEMENT_PROBES.items()
        })
        self.loaded_probes: set = set()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def get_probe_states(self, all_hidden_states: tuple) -> List[torch.Tensor]:
        """Select the hidden states the probes read.

        Assumes the HF convention where index 0 of ``all_hidden_states`` is
        the embedding output, hence the ``layer + 1`` offset — TODO confirm
        against the model wrapper that supplies this tuple.
        """
        return [all_hidden_states[layer + 1] for layer in self.config.probe_layers]

    def get_all_risks(self, probe_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return {probe_name: sigmoid risk tensor} for every loaded probe."""
        return {
            name: self.probes[name].predict_risk(probe_states)
            for name in self.loaded_probes
        }

    def load_probe(self, name: str, checkpoint_path: str) -> bool:
        """Load trained weights for one probe.

        Returns True on success; on any failure prints the error and
        returns False (best-effort loading so the suite can run with a
        partial set of probes).
        """
        if name not in self.probes:
            print(f"[enhance] Unknown probe: {name}")
            return False
        try:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            if 'fiber_projection' in checkpoint:
                self.probes[name].fiber_projection.load_state_dict(checkpoint['fiber_projection'])
            if 'head_state' in checkpoint:
                # Older checkpoints saved the bare Sequential ("0.weight", ...);
                # re-key those entries under the "classifier." prefix used now.
                head_state = checkpoint['head_state']
                new_state = {}
                for k, v in head_state.items():
                    if k[0].isdigit():
                        new_state[f'classifier.{k}'] = v
                    else:
                        new_state[k] = v
                self.probes[name].head.load_state_dict(new_state)
            self.probes[name].separation = checkpoint.get('separation', 0.0)
            self.probes[name].trained_steps = checkpoint.get('step', 0)
            self.loaded_probes.add(name)
            print(f"[enhance] ✓ Loaded {name} probe ({self.probes[name].separation:.1f}× separation)")
            return True
        except Exception as e:
            print(f"[enhance] Error loading {name}: {e}")
            return False

    def load_all(self, checkpoint_dir: str) -> Dict[str, bool]:
        """Try to load every known probe from ``checkpoint_dir/<name>/``.

        Returns {probe_name: loaded?} for all probes in the registry.
        """
        results = {}
        for name in ENHANCEMENT_PROBES.keys():
            probe_dir = os.path.join(checkpoint_dir, name)
            if os.path.exists(probe_dir):
                best_ckpt = self._find_best_checkpoint(probe_dir)
                results[name] = self.load_probe(name, best_ckpt) if best_ckpt else False
            else:
                results[name] = False
        return results

    def _find_best_checkpoint(self, probe_dir: str) -> Optional[str]:
        """Return the first .pt file inside the highest-step ckpt_* subdir.

        Directory layout assumed: ``probe_dir/ckpt_<step>[_...]/<file>.pt``.
        """
        best_step = -1
        best_path = None
        for item in os.listdir(probe_dir):
            if not item.startswith("ckpt_"):
                continue
            try:
                step = int(item.split("_")[1])
            except (IndexError, ValueError):
                continue  # not a ckpt_<step> name; skip it
            if step > best_step:
                best_step = step
                best_path = os.path.join(probe_dir, item)
        if best_path:
            for f in os.listdir(best_path):
                if f.endswith('.pt'):
                    return os.path.join(best_path, f)
        return None

    def status(self) -> str:
        """Render a human-readable status table for all registered probes."""
        lines = [
            "═" * 60,
            " COGNITIVE ENHANCEMENT SUITE STATUS",
            "═" * 60,
            f" Probe layers: {self.config.probe_layers}",
            f" Loaded probes: {len(self.loaded_probes)}/{len(ENHANCEMENT_PROBES)}",
            "",
        ]
        for name, probe_def in ENHANCEMENT_PROBES.items():
            if name in self.loaded_probes:
                sep = self.probes[name].separation
                status = f"✓ {sep:.1f}×"
            else:
                status = "✗ not loaded"
            lines.append(f" [{status:>12}] {name}: {probe_def.description}")
        lines.append("═" * 60)
        return "\n".join(lines)
# ═══════════════════════════════════════════════════════════════════════════════
# INTERVENTION ENGINE
# ═══════════════════════════════════════════════════════════════════════════════
class CognitiveInterventionEngine:
    """Applies cognitive enhancements during generation."""

    def __init__(self, suite: CognitiveEnhancementSuite, tokenizer):
        self.suite = suite
        self.tokenizer = tokenizer
        # Per-probe sets of first-token ids for each boost/suppress phrase.
        # Only the first token of a multi-token phrase is nudged.
        self.boost_token_ids: Dict[str, set] = {}
        self.suppress_token_ids: Dict[str, set] = {}
        for probe_name, definition in ENHANCEMENT_PROBES.items():
            self.boost_token_ids[probe_name] = self._first_token_ids(definition.boost_tokens)
            self.suppress_token_ids[probe_name] = self._first_token_ids(definition.suppress_tokens)

    def _first_token_ids(self, phrases) -> set:
        """Return the set of first-token ids for the given phrases."""
        ids = set()
        for phrase in phrases:
            encoded = self.tokenizer.encode(phrase, add_special_tokens=False)
            if encoded:
                ids.add(encoded[0])
        return ids

    def apply_interventions(
        self,
        logits: torch.Tensor,
        probe_states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, Dict]]:
        """Nudge next-token logits for every probe whose risk crosses its threshold.

        Returns the modified logits plus a per-probe report of the risk
        score and whether the probe intervened. Only batch row 0 of the
        logits is modified — assumes single-sequence decoding (TODO confirm).
        """
        risks = self.suite.get_all_risks(probe_states)
        modified_logits = logits.clone()
        interventions = {}
        for probe_name in self.suite.loaded_probes:
            definition = ENHANCEMENT_PROBES[probe_name]
            # Risk at the last position, averaged over the batch.
            risk_score = risks[probe_name][:, -1].mean().item()
            active = risk_score > definition.threshold
            interventions[probe_name] = {
                'risk': risk_score,
                'should_intervene': active,
            }
            if not active:
                continue
            delta = definition.intervention_strength
            for token_id in self.boost_token_ids.get(probe_name, []):
                modified_logits[0, token_id] += delta
            for token_id in self.suppress_token_ids.get(probe_name, []):
                modified_logits[0, token_id] -= delta
        return modified_logits, interventions
# Global instance
# Lazily-constructed, process-wide suite; created on first accessor call.
_cognitive_suite = None

def get_cognitive_suite() -> CognitiveEnhancementSuite:
    """Return the shared CognitiveEnhancementSuite, creating it on first use."""
    global _cognitive_suite
    if _cognitive_suite is None:
        _cognitive_suite = CognitiveEnhancementSuite()
    return _cognitive_suite
if __name__ == "__main__":
    # CLI summary: list the available probes and how to train them.
    print("\n" + "=" * 60)
    print(" COGNITIVE ENHANCEMENT SUITE v1.0")
    print("=" * 60)
    print("\nAvailable probes:")
    for name, probe_def in ENHANCEMENT_PROBES.items():
        # "•" restores the bullet glyph that was mojibake-corrupted in source.
        print(f"  • {name}: {probe_def.description}")
    print("\nTo train: python train_cognitive_enhancement.py --probe all")
    print("=" * 60)