"""
GLADIUS Plug — Cognitive adapter for external models.
Any model can rent GLADIUS's 170M cognitive parameters through a learned membrane.
The idea: a frozen LLM (GPT-2, Qwen, any VLM) produces hidden states.
Those hidden states project through a thin learned membrane into GLADIUS's
hidden dimension, then flow through the full GLADIUS layer stack —
depth cache, synthase gates, attention, memory — emerging as
cognitively enriched representations with a PUP uncertainty manifold.
Only the membrane learns. GLADIUS stays frozen. The mind stays the same.
The skin is swappable.
"There is no such thing as multi-modal." — Ali
Architecture:
External Model (frozen) → hidden_states [B, S, ext_dim]
→ Membrane (learned) → [B, S, 640]
→ GLADIUS Layers (frozen) → [B, S, 640]
→ PUP Head (frozen) → uncertainty manifold (μ, σ², c)
The membrane is the only learned component: external_dim × 640 + 640 (Linear bias) + 2 × 640 (LayerNorm weight and bias).
For GPT-2 (768→640): 493,440 params. For Qwen-1.7B (2048→640): 1,312,640 params.
Everything else: frozen cognitive infrastructure.
Authors: Ali A. Shakil, Ava Shakil
Date: March 31, 2026
"""
import dataclasses
import os
import sys
from pathlib import Path
from typing import Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class Membrane(nn.Module):
    """
    Trainable bridge from an external model's hidden size to GLADIUS's.

    In a Plug setup this is the only module that learns. It translates
    another model's representation space into GLADIUS's native hidden
    dimension with a linear projection followed by layer normalisation:

        Linear(external_dim → gladius_dim), then LayerNorm(gladius_dim)

    Args:
        external_dim: width of the incoming hidden states.
        gladius_dim: width expected by the GLADIUS layer stack (default 640).
    """

    def __init__(self, external_dim: int, gladius_dim: int = 640):
        super().__init__()
        # Submodule registration order (proj, then norm) determines the
        # state_dict key order, so it is kept fixed.
        self.proj = nn.Linear(external_dim, gladius_dim)
        self.norm = nn.LayerNorm(gladius_dim)
        self.external_dim, self.gladius_dim = external_dim, gladius_dim
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the projection: Xavier-uniform weight, zero bias."""
        linear = self.proj
        nn.init.xavier_uniform_(linear.weight)
        nn.init.zeros_(linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map external hidden states into GLADIUS space.

        Args:
            x: tensor whose last dimension is external_dim, typically
               [batch, seq_len, external_dim] from the external model.

        Returns:
            The same leading shape with the last dimension replaced by
            gladius_dim, ready for the GLADIUS layer stack.
        """
        projected = self.proj(x)
        return self.norm(projected)
class GladiusPlug(nn.Module):
    """
    Wraps a trained GLADIUS kernel as a frozen cognitive adapter.

    The Plug loads a GLADIUS checkpoint, freezes it, and exposes its
    transformer layer stack through a learned membrane. External models
    produce hidden states → membrane projects to GLADIUS dim → layers
    process with depth cache and attention → PUP reads uncertainty.

    Usage:
        plug = GladiusPlug("checkpoint.pt", external_dim=768)
        enriched, pup_manifold = plug(gpt2_hidden_states)
        # Only membrane trains
        optimizer = torch.optim.Adam(plug.membrane_params(), lr=1e-4)
    """

    def __init__(
        self,
        checkpoint_path: str,
        external_dim: int,
        freeze_gladius: bool = True,
        device: str = 'cpu',
    ):
        """
        Load a GLADIUS checkpoint, rebuild its kernel, and attach a membrane.

        Args:
            checkpoint_path: path to a torch checkpoint. It must contain a
                'config' entry. Weights are read from 'model_state_dict' or,
                failing that, 'state_dict'. The optional boolean flags
                'synthase' and 'pup' trigger the matching kernel upgrades,
                and an optional 'step' is recorded for reporting.
            external_dim: last-dimension size of the external hidden states.
            freeze_gladius: if True, disable gradients for every kernel
                parameter and keep the kernel in eval mode (see train()).
            device: used as torch.load's map_location and for self.to().

        Raises:
            FileNotFoundError: checkpoint_path does not exist.
            ValueError: the checkpoint has no 'config' entry.
            ImportError: the GLADIUS kernel source cannot be located.
        """
        super().__init__()
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # NOTE(security): weights_only=False uses the full pickle loader,
        # which can execute arbitrary code, so only load trusted checkpoints.
        # It is kept because 'config' may be stored as a pickled dataclass.
        ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=False)
        config_dict = self._config_to_dict(ckpt.get('config'))

        # Make the kernel package importable, then build it from source.
        gladius_src = self._find_kernel_source(Path(__file__).parent.parent)
        if gladius_src not in sys.path:
            sys.path.insert(0, gladius_src)
        from kernel import GladiusKernel
        from kernel.config import KernelConfig

        config = KernelConfig(**self._filter_kernel_config(config_dict, KernelConfig))
        self.kernel = GladiusKernel(config)

        # strict=False: keys for components the kernel does not (yet) have
        # are skipped instead of aborting the load.
        state_dict = ckpt.get('model_state_dict', ckpt.get('state_dict', {}))
        if not state_dict:
            print("Warning: Checkpoint has no 'model_state_dict'/'state_dict'. "
                  "Kernel weights are left at their initial values.")
        self.kernel.load_state_dict(state_dict, strict=False)

        self._has_synthase = self._apply_synthase(ckpt, state_dict)
        self.pup_head = None
        self._has_pup = self._apply_pup(ckpt, state_dict)

        # Freeze GLADIUS kernel (the whole point)
        if freeze_gladius:
            self.kernel.requires_grad_(False)
            self.kernel.eval()

        # Dimensions come from the rebuilt config.
        self.gladius_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.max_seq_len = config.max_seq_len
        self.config = config
        self._step = ckpt.get('step', 0)
        self._frozen = freeze_gladius

        # Created after freezing so its parameters stay trainable — the ONLY
        # learned component.
        self.membrane = Membrane(external_dim, self.gladius_dim)

        self.to(device)
        self._report()

    @staticmethod
    def _config_to_dict(config_raw) -> dict:
        """
        Normalise a checkpoint 'config' entry to a plain dict.

        Accepts a dataclass instance, a dict, anything dict() accepts
        (mappings, key/value pairs), or a plain object with attributes
        (e.g. argparse.Namespace).

        Raises:
            ValueError: config_raw is None (the checkpoint had no config).
        """
        if config_raw is None:
            raise ValueError("Checkpoint missing 'config' key")
        if dataclasses.is_dataclass(config_raw) and not isinstance(config_raw, type):
            return dataclasses.asdict(config_raw)
        if isinstance(config_raw, dict):
            return config_raw
        try:
            return dict(config_raw)
        except (TypeError, ValueError):
            # Plain objects are not iterable as key/value pairs; fall back
            # to their attribute dictionary when they have one.
            if hasattr(config_raw, '__dict__'):
                return dict(vars(config_raw))
            raise

    @staticmethod
    def _filter_kernel_config(config_dict: dict, config_cls) -> dict:
        """
        Keep only the fields config_cls declares and repair serialised values.

        Args:
            config_dict: raw config values from the checkpoint.
            config_cls: dataclass type (KernelConfig) whose fields are valid.

        Returns:
            A new dict safe to pass as config_cls(**result).
        """
        valid_fields = {f.name for f in dataclasses.fields(config_cls)}
        filtered = {k: v for k, v in config_dict.items() if k in valid_fields}

        # dtype may have been serialised as a string such as 'torch.float16'.
        if 'dtype' in filtered:
            dtype_val = filtered['dtype']
            if isinstance(dtype_val, str):
                resolved = getattr(torch, dtype_val.replace('torch.', ''), None)
                # Reject torch attributes that are not dtypes (e.g. 'Tensor').
                filtered['dtype'] = resolved if isinstance(resolved, torch.dtype) else torch.float32
            elif not isinstance(dtype_val, torch.dtype):
                filtered['dtype'] = torch.float32

        # Ensure cold_embedding_dim matches hidden_dim. Only inject it when
        # config_cls declares the field; otherwise config_cls(**filtered)
        # would raise TypeError on an unexpected keyword.
        if 'cold_embedding_dim' in valid_fields and (
            'cold_embedding_dim' not in filtered
            or filtered.get('cold_embedding_dim') != filtered.get('hidden_dim')
        ):
            filtered['cold_embedding_dim'] = filtered.get('hidden_dim', 640)
        return filtered

    def _apply_synthase(self, ckpt: dict, state_dict) -> bool:
        """
        Upgrade the kernel with synthase surgery if the checkpoint's flag is set.

        Returns True when the upgrade ran. Returns False when the flag is off
        or synthase_surgery cannot be imported, and prints a warning in the
        latter case.
        """
        if not ckpt.get('synthase', False):
            return False
        try:
            from synthase.synthase_surgery import upgrade_to_synthase
        except ImportError:
            print("Warning: Checkpoint has synthase but synthase_surgery not found. Skipping.")
            return False
        upgrade_to_synthase(self.kernel)
        # Reload weights to pick up synthase parameters attached by the upgrade.
        self.kernel.load_state_dict(state_dict, strict=False)
        return True

    def _apply_pup(self, ckpt: dict, state_dict) -> bool:
        """
        Attach the PUP uncertainty head if the checkpoint's 'pup' flag is set.

        On success, sets self.pup_head to self.kernel.pup_head and returns
        True. Returns False when the flag is off or pup_surgery cannot be
        imported, and prints a warning in the latter case.
        """
        if not ckpt.get('pup', False):
            return False
        try:
            from pup.pup_surgery import upgrade_kernel_to_pup
        except ImportError:
            print("Warning: Checkpoint has PUP but pup_surgery not found. Skipping.")
            return False
        upgrade_kernel_to_pup(self.kernel)
        # The initial strict=False load ran before this upgrade, so any
        # pup_head.* parameters the upgrade attaches were skipped then.
        # Reload, as the synthase path does. If they were already present,
        # this rewrites the same values.
        self.kernel.load_state_dict(state_dict, strict=False)
        self.pup_head = self.kernel.pup_head
        return True

    def train(self, mode: bool = True):
        """
        Set train/eval mode, keeping a frozen kernel in eval mode.

        nn.Module.train() recurses into every child. Without this override,
        plug.train() would switch the frozen kernel (and any train-only
        behaviour inside it, e.g. dropout if present) back to training.
        When the kernel is frozen, only the membrane follows `mode`.
        """
        super().train(mode)
        if getattr(self, '_frozen', False):
            self.kernel.eval()
        return self

    def _find_kernel_source(self, start: Path) -> str:
        """
        Locate the directory that contains the GLADIUS `kernel` package.

        Starting at `start`, checks `<dir>/src/kernel/kernel.py` for `start`
        and up to five ancestors. If none match, checks
        `$GLADIUS_WORKSPACE/gladius_v2/src` (GLADIUS_WORKSPACE defaults to the
        current directory).

        Returns:
            The matching `src` directory as a string (for sys.path).

        Raises:
            ImportError: no candidate contains kernel/kernel.py.
        """
        # Layout note: this file may live under gladius_v2/staging/kernel/plug/,
        # while the GladiusKernel implementation lives in gladius_v2/src/.
        current = start
        for _ in range(6):
            candidate = current / 'src'
            if (candidate / 'kernel' / 'kernel.py').exists():
                return str(candidate)
            current = current.parent

        workspace = Path(os.environ.get('GLADIUS_WORKSPACE', '.'))
        gladius_src = workspace / 'gladius_v2' / 'src'
        if (gladius_src / 'kernel' / 'kernel.py').exists():
            return str(gladius_src)
        raise ImportError(
            "Cannot find GLADIUS kernel source (kernel/kernel.py). "
            "Expected in gladius_v2/src/ or parent directories of plug/."
        )

    def forward(
        self,
        external_hidden_states: torch.Tensor,
        return_pup: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """
        Project external representations through the GLADIUS cognitive stack.

        Args:
            external_hidden_states: [batch, seq_len, external_dim]
                Hidden states from any external model (GPT-2, Qwen, VLM, etc.).
                Sequences longer than max_seq_len are truncated.
            return_pup: whether to compute the PUP uncertainty manifold.

        Returns:
            enriched: [batch, seq_len, gladius_dim] — depth-enriched representations
            pup_manifold: dict with mu, sigma, confidence, log_var (or None when
                return_pup is False or no PUP head is attached)
        """
        # Unpacking also enforces a rank-3 input.
        _, seq_len, _ = external_hidden_states.shape
        if seq_len > self.max_seq_len:
            external_hidden_states = external_hidden_states[:, :self.max_seq_len, :]

        # Project through membrane (external_dim → gladius_dim).
        x = self.membrane(external_hidden_states)
        # Run through GLADIUS transformer layers (bypassing embedding).
        enriched = self._forward_through_layers(x)

        pup_manifold = None
        if return_pup and self.pup_head is not None:
            pup_manifold = self.pup_head(hidden=enriched)
        return enriched, pup_manifold

    def _forward_through_layers(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run x through the kernel's transformer layers, skipping token embedding.

        Each entry of self.kernel.layers is called as layer(x, mask=mask), so
        standard and synthase-upgraded layers are handled the same way. The
        kernel's final_norm is applied afterwards when it exists.
        """
        _, seq_len, _ = x.shape
        # Reuse the kernel's causal mask when available (same format as
        # GladiusKernel.forward). The slice assumes a 4-D [1, 1, L, L]-style
        # buffer — TODO confirm against kernel.py.
        if seq_len <= self.max_seq_len and hasattr(self.kernel, 'causal_mask'):
            mask = self.kernel.causal_mask[:, :, :seq_len, :seq_len]
        else:
            mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, device=x.device))

        for layer in self.kernel.layers:
            x = layer(x, mask=mask)

        if hasattr(self.kernel, 'final_norm'):
            x = self.kernel.final_norm(x)
        return x

    def membrane_params(self):
        """Return only membrane parameters (for optimizer)."""
        return self.membrane.parameters()

    def membrane_param_count(self) -> int:
        """Count of trainable membrane parameters."""
        return sum(p.numel() for p in self.membrane.parameters())

    def kernel_param_count(self) -> int:
        """Count of kernel parameters (frozen when freeze_gladius=True)."""
        return sum(p.numel() for p in self.kernel.parameters())

    def save_membrane(self, path: str):
        """Save only the membrane weights plus its dimensions and the kernel step."""
        torch.save({
            'membrane_state_dict': self.membrane.state_dict(),
            'external_dim': self.membrane.external_dim,
            'gladius_dim': self.membrane.gladius_dim,
            'kernel_step': self._step,
        }, path)
        print(f"Membrane saved: {path} ({self.membrane_param_count():,} params)")

    def load_membrane(self, path: str):
        """
        Load membrane weights from a file written by save_membrane.

        A bare state_dict (without the 'membrane_state_dict' wrapper) is also
        accepted.
        """
        # NOTE(security): depending on the installed torch version, torch.load
        # may default to the full pickle loader, so only load trusted files.
        data = torch.load(path, map_location='cpu')
        state = data.get('membrane_state_dict', data)
        self.membrane.load_state_dict(state)
        print(f"Membrane loaded: {path}")

    def _report(self):
        """Print a summary of the Plug's configuration and parameter counts."""
        membrane_p = self.membrane_param_count()
        kernel_p = self.kernel_param_count()
        total_p = membrane_p + kernel_p
        # Guard against a parameter-less kernel (avoids ZeroDivisionError).
        overhead = membrane_p / kernel_p * 100 if kernel_p else float('inf')
        bar = '=' * 55
        print(f"\n{bar}")
        print(" GLADIUS PLUG — Cognitive Adapter")
        print(bar)
        print(f" Kernel: {kernel_p:>12,} params (frozen={self._frozen})")
        print(f" Membrane: {membrane_p:>12,} params (TRAINABLE)")
        print(f" Total: {total_p:>12,} params")
        print(f" Overhead: {overhead:.3f}%")
        print(f" External dim: {self.membrane.external_dim}")
        print(f" GLADIUS dim: {self.gladius_dim}")
        print(f" Layers: {self.num_layers}")
        print(f" Synthase: {'yes' if self._has_synthase else 'no'}")
        print(f" PUP: {'yes' if self._has_pup else 'no'}")
        print(f" From step: {self._step:,}")
        print(f"{bar}\n")