# Gladius / extensions/plug/plug.py
# ava-shakil's picture
# Clean repo: remove non-research content, fix paths, update docs
# 6bec07e
"""
GLADIUS Plug — Cognitive adapter for external models.
Any model can rent GLADIUS's 170M cognitive parameters through a learned membrane.
The idea: a frozen LLM (GPT-2, Qwen, any VLM) produces hidden states.
Those hidden states project through a thin learned membrane into GLADIUS's
hidden dimension, then flow through the full GLADIUS layer stack —
depth cache, synthase gates, attention, memory — emerging as
cognitively enriched representations with a PUP uncertainty manifold.
Only the membrane learns. GLADIUS stays frozen. The mind stays the same.
The skin is swappable.
"There is no such thing as multi-modal." — Ali
Architecture:
External Model (frozen) → hidden_states [B, S, ext_dim]
→ Membrane (learned) → [B, S, 640]
→ GLADIUS Layers (frozen) → [B, S, 640]
→ PUP Head (frozen) → uncertainty manifold (μ, σ², c)
The membrane is the only learned component: external_dim × 640 projection weights + 640 projection bias + 2 × 640 LayerNorm affine parameters.
For GPT-2 (768→640): 493,440 params. For Qwen-1.7B (2048→640): 1,312,640 params.
Everything else: frozen cognitive infrastructure.
Authors: Ali A. Shakil, Ava Shakil
Date: March 31, 2026
"""
import dataclasses
import os
import sys
from pathlib import Path
from typing import Optional, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class Membrane(nn.Module):
    """
    Trainable bridge from an external model's hidden size to GLADIUS's.

    In a Plug setup this is the only module whose weights are updated. It is a
    linear map ``external_dim -> gladius_dim`` followed by a LayerNorm over the
    GLADIUS dimension.

    Args:
        external_dim: feature size of the incoming hidden states.
        gladius_dim: target feature size (defaults to 640).
    """

    def __init__(self, external_dim: int, gladius_dim: int = 640):
        super().__init__()
        # Submodules are created in this order (projection, then norm) so the
        # parameter names and ordering are proj.* followed by norm.*.
        self.proj = nn.Linear(external_dim, gladius_dim)
        self.norm = nn.LayerNorm(gladius_dim)
        self.external_dim, self.gladius_dim = external_dim, gladius_dim
        self._init_weights()

    def _init_weights(self):
        """Re-initialize the projection: Xavier-uniform weight, zero bias.

        The LayerNorm keeps PyTorch's default init (weight=1, bias=0).
        """
        linear = self.proj
        nn.init.xavier_uniform_(linear.weight)
        nn.init.zeros_(linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project and normalize external hidden states.

        Args:
            x: tensor whose last axis is ``external_dim`` (typically
               ``[batch, seq_len, external_dim]``; both layers act on the
               last axis only).

        Returns:
            Tensor with the same leading shape and last axis ``gladius_dim``.
        """
        projected = self.proj(x)
        return self.norm(projected)
class GladiusPlug(nn.Module):
    """
    Wraps a trained GLADIUS kernel as a frozen cognitive adapter.

    The Plug loads a GLADIUS checkpoint, optionally freezes it, and exposes its
    transformer layer stack through a learned membrane. External models
    produce hidden states -> membrane projects to GLADIUS dim -> kernel layers
    process them -> the PUP head (if the checkpoint has one) reads uncertainty.

    Usage:
        plug = GladiusPlug("checkpoint.pt", external_dim=768)
        enriched, pup_manifold = plug(gpt2_hidden_states)
        # Only membrane trains
        optimizer = torch.optim.Adam(plug.membrane_params(), lr=1e-4)
    """

    def __init__(
        self,
        checkpoint_path: str,
        external_dim: int,
        freeze_gladius: bool = True,
        device: str = 'cpu',
    ):
        """
        Args:
            checkpoint_path: GLADIUS checkpoint with a 'config' entry and weights
                under 'model_state_dict' (or 'state_dict').
            external_dim: hidden size of the external model feeding the membrane.
            freeze_gladius: if True, kernel parameters get requires_grad=False and
                the kernel stays in eval mode (including across later .train()).
            device: map_location for torch.load and final placement of the plug.

        Raises:
            FileNotFoundError: checkpoint_path does not exist.
            ValueError: the checkpoint has no 'config' entry.
            ImportError: the GLADIUS kernel source cannot be located.
        """
        super().__init__()
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects. It is needed
        # because the config may be stored as a dataclass instance, but it means
        # only trusted checkpoints should ever be loaded here.
        ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=False)

        config_raw = ckpt.get('config')
        if config_raw is None:
            raise ValueError("Checkpoint missing 'config' key")
        config_dict = self._config_to_dict(config_raw)

        # Make the GLADIUS `kernel` package importable, then import it lazily.
        gladius_src = self._find_kernel_source(Path(__file__).parent.parent)
        if str(gladius_src) not in sys.path:
            sys.path.insert(0, str(gladius_src))
        from kernel import GladiusKernel
        from kernel.config import KernelConfig

        valid_fields = {f.name for f in dataclasses.fields(KernelConfig)}
        config = KernelConfig(**self._filter_config(config_dict, valid_fields))
        self.kernel = GladiusKernel(config)

        # Load model weights. strict=False tolerates optional components
        # (e.g. surgery-added modules) that are absent on one side.
        state_dict = ckpt.get('model_state_dict', ckpt.get('state_dict', {}))
        if not state_dict:
            print("Warning: Checkpoint has no 'model_state_dict'/'state_dict'; "
                  "GLADIUS kernel weights are left at their initialization.")
        self.kernel.load_state_dict(state_dict, strict=False)

        # Synthase upgrade, if the checkpoint was trained with it.
        self._has_synthase = bool(ckpt.get('synthase', False))
        if self._has_synthase:
            try:
                from synthase.synthase_surgery import upgrade_to_synthase
            except ImportError:
                print("Warning: Checkpoint has synthase but synthase_surgery not found. Skipping.")
                self._has_synthase = False
            else:
                upgrade_to_synthase(self.kernel)
                # Reload weights to pick up synthase parameters.
                self.kernel.load_state_dict(state_dict, strict=False)

        # PUP head, if the checkpoint was trained with it.
        self._has_pup = bool(ckpt.get('pup', False))
        self.pup_head = None
        if self._has_pup:
            try:
                from pup.pup_surgery import upgrade_kernel_to_pup
            except ImportError:
                print("Warning: Checkpoint has PUP but pup_surgery not found. Skipping.")
                self._has_pup = False
            else:
                upgrade_kernel_to_pup(self.kernel)
                # If the upgrade creates kernel.pup_head, its pup_head.* keys
                # were skipped by the strict=False load above. Reload so they are
                # applied, mirroring the synthase path.
                self.kernel.load_state_dict(state_dict, strict=False)
                self.pup_head = self.kernel.pup_head

        # Freeze GLADIUS kernel (the whole point).
        if freeze_gladius:
            self.kernel.requires_grad_(False)
            self.kernel.eval()

        # Dimensions from the loaded config.
        self.gladius_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.max_seq_len = config.max_seq_len
        self.config = config
        self._step = ckpt.get('step') or 0
        self._frozen = freeze_gladius

        # The membrane is the ONLY learned component.
        self.membrane = Membrane(external_dim, self.gladius_dim)

        self.to(device)
        self._report()

    @staticmethod
    def _config_to_dict(config_raw) -> dict:
        """Normalize a checkpoint config to a plain dict.

        Accepts a dataclass instance, a dict, anything dict() accepts (a mapping
        or an iterable of pairs), or a plain object with attributes (via vars()).
        """
        if dataclasses.is_dataclass(config_raw) and not isinstance(config_raw, type):
            return dataclasses.asdict(config_raw)
        if isinstance(config_raw, dict):
            return config_raw
        try:
            return dict(config_raw)
        except (TypeError, ValueError):
            return vars(config_raw)

    @staticmethod
    def _filter_config(config_dict: dict, valid_fields: set) -> dict:
        """Build KernelConfig kwargs from a raw config dict.

        Drops keys that are not in ``valid_fields``. Resolves a serialized
        ``dtype`` (e.g. "torch.float16") to a torch.dtype, falling back to
        float32. Forces ``cold_embedding_dim`` to equal ``hidden_dim``
        (default 640), but only when that field exists on KernelConfig.
        """
        filtered = {k: v for k, v in config_dict.items() if k in valid_fields}

        if 'dtype' in filtered:
            dtype_val = filtered['dtype']
            if isinstance(dtype_val, str):
                resolved = getattr(torch, dtype_val.replace('torch.', ''), None)
                # Guard against names that resolve to non-dtype torch attributes.
                filtered['dtype'] = resolved if isinstance(resolved, torch.dtype) else torch.float32
            elif not isinstance(dtype_val, torch.dtype):
                filtered['dtype'] = torch.float32

        if 'cold_embedding_dim' in valid_fields:
            filtered['cold_embedding_dim'] = filtered.get('hidden_dim', 640)
        return filtered

    def _find_kernel_source(self, start: Path) -> str:
        """
        Locate the directory that contains the GLADIUS ``kernel`` package.

        Checks ``<dir>/src/kernel/kernel.py`` for ``start`` and up to five of its
        ancestors. Failing that, it tries ``$GLADIUS_WORKSPACE/gladius_v2/src``;
        the workspace defaults to the current working directory.

        Returns:
            The matching ``src`` directory as a string, suitable for sys.path.

        Raises:
            ImportError: if no candidate contains ``kernel/kernel.py``.
        """
        current = Path(start)
        for _ in range(6):
            candidate = current / 'src'
            if (candidate / 'kernel' / 'kernel.py').exists():
                return str(candidate)
            current = current.parent

        workspace = Path(os.environ.get('GLADIUS_WORKSPACE', '.'))
        gladius_src = workspace / 'gladius_v2' / 'src'
        if (gladius_src / 'kernel' / 'kernel.py').exists():
            return str(gladius_src)

        raise ImportError(
            "Cannot find GLADIUS kernel source (kernel/kernel.py). "
            "Expected in gladius_v2/src/ or parent directories of plug/."
        )

    def train(self, mode: bool = True) -> 'GladiusPlug':
        """
        Set training mode on the plug while keeping a frozen kernel in eval mode.

        nn.Module.train() recurses into every child. Without this override,
        ``plug.train()`` would switch the frozen GLADIUS kernel back to train
        mode, including any train-only behavior it has (such as dropout).
        """
        super().train(mode)
        if getattr(self, '_frozen', False):
            self.kernel.eval()
        return self

    def forward(
        self,
        external_hidden_states: torch.Tensor,
        return_pup: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """
        Project external representations through the GLADIUS cognitive stack.

        Args:
            external_hidden_states: [batch, seq_len, external_dim] hidden states
                from an external model. Sequences longer than max_seq_len are
                truncated to their first max_seq_len positions.
            return_pup: whether to run the PUP head (ignored if there is none).

        Returns:
            enriched: [batch, seq_len', gladius_dim] output of the kernel layers.
            pup_manifold: output of the PUP head, or None. The original docs
                describe it as a dict with mu, sigma, confidence, log_var.
        """
        _, seq_len, _ = external_hidden_states.shape

        if seq_len > self.max_seq_len:
            external_hidden_states = external_hidden_states[:, :self.max_seq_len, :]

        x = self.membrane(external_hidden_states)
        enriched = self._forward_through_layers(x)

        pup_manifold = None
        if return_pup and self.pup_head is not None:
            pup_manifold = self.pup_head(hidden=enriched)
        return enriched, pup_manifold

    def _forward_through_layers(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run ``x`` through ``kernel.layers``, bypassing the token embedding.

        The causal mask is ``kernel.causal_mask[:, :, :S, :S]`` when that buffer
        exists and S <= max_seq_len. Otherwise a lower-triangular float mask of
        shape [1, 1, S, S] is built. Each layer is called as
        ``layer(x, mask=mask)``; presumably synthase-upgraded layers share this
        signature — TODO confirm. ``kernel.final_norm`` is applied if present.
        """
        _, seq_len, _ = x.shape

        if seq_len <= self.max_seq_len and hasattr(self.kernel, 'causal_mask'):
            mask = self.kernel.causal_mask[:, :, :seq_len, :seq_len]
        else:
            mask = torch.tril(torch.ones(1, 1, seq_len, seq_len, device=x.device))

        for layer in self.kernel.layers:
            x = layer(x, mask=mask)

        if hasattr(self.kernel, 'final_norm'):
            x = self.kernel.final_norm(x)
        return x

    def membrane_params(self):
        """Return only membrane parameters (for the optimizer)."""
        return self.membrane.parameters()

    def membrane_param_count(self) -> int:
        """Count of trainable membrane parameters."""
        return sum(p.numel() for p in self.membrane.parameters())

    def kernel_param_count(self) -> int:
        """Count of kernel parameters (including any PUP/synthase additions)."""
        return sum(p.numel() for p in self.kernel.parameters())

    def save_membrane(self, path: str):
        """Save only the membrane weights and their dimensions (a small file)."""
        torch.save({
            'membrane_state_dict': self.membrane.state_dict(),
            'external_dim': self.membrane.external_dim,
            'gladius_dim': self.membrane.gladius_dim,
            'kernel_step': self._step,
        }, path)
        print(f"Membrane saved: {path} ({self.membrane_param_count():,} params)")

    def load_membrane(self, path: str):
        """Load membrane weights from a save_membrane() file or a raw state dict.

        NOTE(review): membranes are meant to be swappable and may come from
        elsewhere. torch.load's default unpickling behavior depends on the torch
        version; only load trusted files.
        """
        data = torch.load(path, map_location='cpu')
        state = data.get('membrane_state_dict', data)
        self.membrane.load_state_dict(state)
        print(f"Membrane loaded: {path}")

    def _report(self):
        """Print a Plug configuration summary."""
        membrane_p = self.membrane_param_count()
        kernel_p = self.kernel_param_count()
        total_p = membrane_p + kernel_p
        # Guard against a parameter-less kernel (would divide by zero).
        overhead = f"{membrane_p / kernel_p * 100:.3f}%" if kernel_p else "n/a"
        bar = '=' * 55
        print(f"\n{bar}")
        print(" GLADIUS PLUG — Cognitive Adapter")
        print(bar)
        print(f" Kernel: {kernel_p:>12,} params (frozen={self._frozen})")
        print(f" Membrane: {membrane_p:>12,} params (TRAINABLE)")
        print(f" Total: {total_p:>12,} params")
        print(f" Overhead: {overhead}")
        print(f" External dim: {self.membrane.external_dim}")
        print(f" GLADIUS dim: {self.gladius_dim}")
        print(f" Layers: {self.num_layers}")
        print(f" Synthase: {'yes' if self._has_synthase else 'no'}")
        print(f" PUP: {'yes' if self._has_pup else 'no'}")
        print(f" From step: {self._step:,}")
        print(f"{bar}\n")