"""
GLADIUS Plug — Cognitive adapter for external models.

Any model can rent GLADIUS's 170M cognitive parameters through a learned
membrane.

The idea: a frozen LLM (GPT-2, Qwen, any VLM) produces hidden states. Those
hidden states project through a thin learned membrane into GLADIUS's hidden
dimension, then flow through the full GLADIUS layer stack — depth cache,
synthase gates, attention, memory — emerging as cognitively enriched
representations with a PUP uncertainty manifold.

Only the membrane learns. GLADIUS stays frozen.
The mind stays the same. The skin is swappable.

"There is no such thing as multi-modal." — Ali

Architecture:
    External Model (frozen) → hidden_states [B, S, ext_dim]
    → Membrane (learned) → [B, S, 640]
    → GLADIUS Layers (frozen) → [B, S, 640]
    → PUP Head (frozen) → uncertainty manifold (μ, σ², c)

The membrane is the only learned component:
    Linear:    external_dim × 640 weights + 640 bias
    LayerNorm: 640 weight + 640 bias
For GPT-2 (768→640): 493,440 params.
For Qwen-1.7B (2048→640): 1,312,640 params.
Everything else: frozen cognitive infrastructure.

Authors: Ali A. Shakil, Ava Shakil
Date: March 31, 2026
"""

import dataclasses
import os
from pathlib import Path
from typing import Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class Membrane(nn.Module):
    """
    Learned projection: external_dim → GLADIUS hidden_dim.

    This is the only trainable component in a Plug setup. It learns to
    translate another model's representation space into GLADIUS's native
    cognitive dimension.

    Architecture: Linear(ext_dim, gladius_dim) + LayerNorm(gladius_dim)
    """

    def __init__(self, external_dim: int, gladius_dim: int = 640):
        """
        Args:
            external_dim: feature size of the external model's hidden states
            gladius_dim: feature size expected by the GLADIUS layer stack
        """
        super().__init__()
        self.proj = nn.Linear(external_dim, gladius_dim)
        self.norm = nn.LayerNorm(gladius_dim)
        self.external_dim = external_dim
        self.gladius_dim = gladius_dim
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform projection weights and zero bias."""
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, external_dim] from any external model
        Returns:
            [batch, seq_len, gladius_dim] ready for GLADIUS layer stack
        """
        return self.norm(self.proj(x))


class GladiusPlug(nn.Module):
    """
    Wraps a trained GLADIUS kernel as a frozen cognitive adapter.

    The Plug loads a GLADIUS checkpoint, optionally freezes it, and exposes
    its transformer layer stack through a learned membrane. External models
    produce hidden states → membrane projects to GLADIUS dim → kernel layers
    process them → the PUP head (if the checkpoint has one) reads uncertainty.

    Usage:
        plug = GladiusPlug("checkpoint.pt", external_dim=768)
        enriched, pup_manifold = plug(gpt2_hidden_states)

        # Only membrane trains
        optimizer = torch.optim.Adam(plug.membrane_params(), lr=1e-4)
    """

    def __init__(
        self,
        checkpoint_path: str,
        external_dim: int,
        freeze_gladius: bool = True,
        device: str = 'cpu',
    ):
        """
        Args:
            checkpoint_path: path to a GLADIUS checkpoint produced by torch.save
            external_dim: hidden size of the external model feeding the plug
            freeze_gladius: disable gradients on the kernel and keep it in eval mode
            device: device the checkpoint is mapped to and the plug is moved to

        Raises:
            FileNotFoundError: the checkpoint file does not exist
            ValueError: the checkpoint has no 'config' entry
            ImportError: the GLADIUS kernel source cannot be located
        """
        super().__init__()

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects
        # (the stored config may be a dataclass instance). Only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=False)

        # Extract config — handle both dataclass and dict forms
        config_raw = ckpt.get('config')
        if config_raw is None:
            raise ValueError("Checkpoint missing 'config' key")

        if dataclasses.is_dataclass(config_raw) and not isinstance(config_raw, type):
            config_dict = dataclasses.asdict(config_raw)
        elif isinstance(config_raw, dict):
            config_dict = config_raw
        else:
            config_dict = dict(config_raw)

        # Build kernel from source (handles import path resolution)
        kernel_src = Path(__file__).parent.parent
        gladius_src = self._find_kernel_source(kernel_src)

        import sys
        if str(gladius_src) not in sys.path:
            sys.path.insert(0, str(gladius_src))

        from kernel import GladiusKernel
        from kernel.config import KernelConfig

        # Filter config to valid KernelConfig fields
        valid_fields = {f.name for f in dataclasses.fields(KernelConfig)}
        filtered = {k: v for k, v in config_dict.items() if k in valid_fields}

        # Handle dtype serialization: accept "torch.float16"/"float16" strings,
        # fall back to float32 for anything that is not a real torch.dtype.
        if 'dtype' in filtered:
            dtype_val = filtered['dtype']
            if isinstance(dtype_val, str):
                dtype_val = getattr(torch, dtype_val.replace('torch.', ''), None)
            if not isinstance(dtype_val, torch.dtype):
                dtype_val = torch.float32
            filtered['dtype'] = dtype_val

        # Ensure cold_embedding_dim matches hidden_dim. Only inject the key
        # when KernelConfig actually declares it, otherwise the constructor
        # below would reject the unknown keyword.
        if 'cold_embedding_dim' in valid_fields:
            hidden_dim = filtered.get('hidden_dim', 640)
            if filtered.get('cold_embedding_dim') != filtered.get('hidden_dim') \
                    or 'cold_embedding_dim' not in filtered:
                filtered['cold_embedding_dim'] = hidden_dim

        config = KernelConfig(**filtered)
        self.kernel = GladiusKernel(config)

        # Load model weights (strict=False for optional components)
        state_dict = ckpt.get('model_state_dict', ckpt.get('state_dict', {}))
        if not state_dict:
            print("Warning: Checkpoint has no 'model_state_dict'/'state_dict'. "
                  "Kernel keeps its initial weights.")
        self.kernel.load_state_dict(state_dict, strict=False)

        # Apply synthase upgrade if checkpoint indicates it
        self._has_synthase = bool(ckpt.get('synthase', False))
        if self._has_synthase:
            try:
                from synthase.synthase_surgery import upgrade_to_synthase
                upgrade_to_synthase(self.kernel)
                # Reload weights to pick up synthase parameters
                self.kernel.load_state_dict(state_dict, strict=False)
            except ImportError:
                print("Warning: Checkpoint has synthase but synthase_surgery not found. Skipping.")
                self._has_synthase = False

        # Apply PUP if checkpoint indicates it
        self._has_pup = bool(ckpt.get('pup', False))
        self.pup_head = None
        if self._has_pup:
            try:
                from pup.pup_surgery import upgrade_kernel_to_pup
                upgrade_kernel_to_pup(self.kernel)
                # Presumably the surgery is what attaches kernel.pup_head, so
                # any pup_head.* entries were dropped by the earlier
                # strict=False loads. Reload so the checkpointed PUP weights
                # are actually applied (same pattern as the synthase branch).
                self.kernel.load_state_dict(state_dict, strict=False)
                self.pup_head = self.kernel.pup_head
            except ImportError:
                print("Warning: Checkpoint has PUP but pup_surgery not found. Skipping.")
                self._has_pup = False

        # Freeze GLADIUS kernel (the whole point)
        if freeze_gladius:
            for p in self.kernel.parameters():
                p.requires_grad = False
            self.kernel.eval()

        # Extract dimensions from loaded config
        self.gladius_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.max_seq_len = config.max_seq_len
        self.config = config
        self._step = ckpt.get('step', 0)
        self._frozen = freeze_gladius

        # Create membrane — the ONLY learned component
        self.membrane = Membrane(external_dim, self.gladius_dim)

        # Move to device
        self.to(device)

        self._report()

    def train(self, mode: bool = True):
        """
        Switch training mode, keeping a frozen kernel in eval mode.

        Without this override, calling ``plug.train()`` before fitting the
        membrane would also put the frozen kernel back into training mode.

        Args:
            mode: True for training mode, False for evaluation mode
        Returns:
            self
        """
        super().train(mode)
        if getattr(self, '_frozen', False):
            self.kernel.eval()
        return self

    def _find_kernel_source(self, start: Path) -> str:
        """
        Find the GLADIUS kernel source directory.

        Walks up to six levels from ``start`` looking for
        ``src/kernel/kernel.py``. Falls back to
        ``$GLADIUS_WORKSPACE/gladius_v2/src`` (workspace defaults to the
        current directory).

        Args:
            start: directory where the upward search begins
        Returns:
            The matching ``src`` directory as a string.
        Raises:
            ImportError: no candidate contains kernel/kernel.py
        """
        # Strategy: walk up looking for src/kernel/kernel.py
        current = start
        for _ in range(6):
            candidate = current / 'src'
            if (candidate / 'kernel' / 'kernel.py').exists():
                return str(candidate)
            current = current.parent

        # Fallback: check gladius_v2 relative to workspace
        workspace = Path(os.environ.get('GLADIUS_WORKSPACE', '.'))
        gladius_src = workspace / 'gladius_v2' / 'src'
        if (gladius_src / 'kernel' / 'kernel.py').exists():
            return str(gladius_src)

        raise ImportError(
            "Cannot find GLADIUS kernel source (kernel/kernel.py). "
            "Expected in gladius_v2/src/ or parent directories of plug/."
        )

    def forward(
        self,
        external_hidden_states: torch.Tensor,
        return_pup: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """
        Project external representations through the GLADIUS cognitive stack.

        Args:
            external_hidden_states: [batch, seq_len, external_dim]
                Hidden states from any external model (GPT-2, Qwen, VLM, etc.)
            return_pup: whether to compute PUP uncertainty manifold

        Returns:
            enriched: [batch, seq_len, gladius_dim] output of the kernel layers;
                seq_len is capped at the kernel's max_seq_len
            pup_manifold: whatever the PUP head returns for ``hidden=enriched``,
                or None when return_pup is False or no PUP head is attached
        """
        # Unpacking also enforces a 3-D input.
        _, seq_len, _ = external_hidden_states.shape

        # Truncate to GLADIUS max sequence length
        if seq_len > self.max_seq_len:
            external_hidden_states = external_hidden_states[:, :self.max_seq_len, :]

        # Project through membrane (external_dim → gladius_dim)
        x = self.membrane(external_hidden_states)

        # Run through GLADIUS transformer layers (bypassing embedding)
        enriched = self._forward_through_layers(x)

        # PUP uncertainty manifold
        pup_manifold = None
        if return_pup and self.pup_head is not None:
            pup_manifold = self.pup_head(hidden=enriched)

        return enriched, pup_manifold

    def _forward_through_layers(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run through GLADIUS transformer layer stack, bypassing token embedding.

        Uses the kernel's own ``causal_mask`` (sliced to the sequence length)
        when it exists, otherwise builds a lower-triangular mask of ones.

        Args:
            x: [batch, seq_len, gladius_dim] membrane output
        Returns:
            [batch, seq_len, gladius_dim] after every layer and, if the kernel
            has one, its final_norm
        """
        _, seq_len, _ = x.shape

        # Build causal mask (same format as GladiusKernel.forward)
        if seq_len <= self.max_seq_len and hasattr(self.kernel, 'causal_mask'):
            mask = self.kernel.causal_mask[:, :, :seq_len, :seq_len]
        else:
            mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, device=x.device))

        # Run through each transformer layer
        for layer in self.kernel.layers:
            x = layer(x, mask=mask)

        # Final norm
        if hasattr(self.kernel, 'final_norm'):
            x = self.kernel.final_norm(x)

        return x

    def membrane_params(self):
        """Return only membrane parameters (for optimizer)."""
        return self.membrane.parameters()

    def membrane_param_count(self) -> int:
        """Count of membrane parameters."""
        return sum(p.numel() for p in self.membrane.parameters())

    def kernel_param_count(self) -> int:
        """Count of kernel parameters."""
        return sum(p.numel() for p in self.kernel.parameters())

    def save_membrane(self, path: str):
        """Save only the membrane weights plus its dimensions and kernel step."""
        torch.save({
            'membrane_state_dict': self.membrane.state_dict(),
            'external_dim': self.membrane.external_dim,
            'gladius_dim': self.membrane.gladius_dim,
            'kernel_step': self._step,
        }, path)
        print(f"Membrane saved: {path} ({self.membrane_param_count():,} params)")

    def load_membrane(self, path: str):
        """
        Load membrane weights from file.

        Accepts either the dict written by save_membrane or a bare state dict.
        """
        data = torch.load(path, map_location='cpu')
        state = data.get('membrane_state_dict', data)
        self.membrane.load_state_dict(state)
        print(f"Membrane loaded: {path}")

    def _report(self):
        """Print Plug configuration summary."""
        membrane_p = self.membrane_param_count()
        kernel_p = self.kernel_param_count()
        total_p = membrane_p + kernel_p
        # Guard the ratio so a parameter-free kernel cannot raise ZeroDivisionError.
        overhead = f"{membrane_p/kernel_p*100:.3f}%" if kernel_p else "n/a"

        print(f"\n{'='*55}")
        print(f" GLADIUS PLUG — Cognitive Adapter")
        print(f"{'='*55}")
        print(f" Kernel: {kernel_p:>12,} params (frozen={self._frozen})")
        print(f" Membrane: {membrane_p:>12,} params (TRAINABLE)")
        print(f" Total: {total_p:>12,} params")
        print(f" Overhead: {overhead}")
        print(f" External dim: {self.membrane.external_dim}")
        print(f" GLADIUS dim: {self.gladius_dim}")
        print(f" Layers: {self.num_layers}")
        print(f" Synthase: {'yes' if self._has_synthase else 'no'}")
        print(f" PUP: {'yes' if self._has_pup else 'no'}")
        print(f" From step: {self._step:,}")
        print(f"{'='*55}\n")