"""
GLADIUS Plug: cognitive adapter for external models.

Any model can rent GLADIUS's 170M cognitive parameters through a learned membrane.

The idea: a frozen LLM (GPT-2, Qwen, any VLM) produces hidden states.
Those hidden states project through a thin learned membrane into GLADIUS's
hidden dimension, then flow through the full GLADIUS layer stack
(depth cache, synthase gates, attention, memory), emerging as
cognitively enriched representations with a PUP uncertainty manifold.

Only the membrane learns. GLADIUS stays frozen. The mind stays the same.
The skin is swappable.

"There is no such thing as multi-modal." - Ali

Architecture:
    External Model (frozen) → hidden_states [B, S, ext_dim]
        → Membrane (learned) → [B, S, 640]
            → GLADIUS Layers (frozen) → [B, S, 640]
                → PUP Head (frozen) → uncertainty manifold (μ, σ², c)

The membrane is the only learned component: external_dim × 640 + 640 (Linear
weight and bias) plus 2 × 640 (LayerNorm weight and bias).
For GPT-2 (768→640): 493,440 params. For Qwen-1.7B (2048→640): 1,312,640 params.
Everything else: frozen cognitive infrastructure.

Authors: Ali A. Shakil, Ava Shakil
Date: March 31, 2026
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Optional, Dict, Tuple
import dataclasses
import os  # needed by GladiusPlug._find_kernel_source (GLADIUS_WORKSPACE lookup)


class Membrane(nn.Module):
    """
    Learned projection: external_dim → GLADIUS hidden_dim.
    
    This is the only trainable component in a Plug setup.
    It learns to translate another model's representation space
    into GLADIUS's native cognitive dimension.
    
    Architecture: Linear(ext_dim, gladius_dim) + LayerNorm(gladius_dim)
    """
    
    def __init__(self, external_dim: int, gladius_dim: int = 640):
        super().__init__()
        self.proj = nn.Linear(external_dim, gladius_dim)
        self.norm = nn.LayerNorm(gladius_dim)
        self.external_dim = external_dim
        self.gladius_dim = gladius_dim
        self._init_weights()
    
    def _init_weights(self):
        """Xavier init for smooth gradient flow at startup."""
        nn.init.xavier_uniform_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, external_dim] from any external model
        Returns:
            [batch, seq_len, gladius_dim] ready for GLADIUS layer stack
        """
        return self.norm(self.proj(x))
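

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. The Membrane above is
# Linear(external_dim, gladius_dim) followed by LayerNorm(gladius_dim), so its
# trainable parameter count can be worked out by hand. The helper name
# `membrane_param_breakdown` is hypothetical; it assumes PyTorch's defaults
# (a bias on nn.Linear, an elementwise affine weight and bias on nn.LayerNorm).
# ---------------------------------------------------------------------------
def membrane_param_breakdown(external_dim: int, gladius_dim: int = 640) -> dict:
    """Return per-tensor parameter counts for a Linear + LayerNorm membrane."""
    counts = {
        "proj.weight": external_dim * gladius_dim,  # Linear weight matrix
        "proj.bias": gladius_dim,                   # Linear bias vector
        "norm.weight": gladius_dim,                 # LayerNorm scale
        "norm.bias": gladius_dim,                   # LayerNorm shift
    }
    counts["total"] = sum(counts.values())
    return counts


# Example: membrane_param_breakdown(768)["total"] gives 493440 for a
# 768-dimensional external model projected into 640 dimensions.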


class GladiusPlug(nn.Module):
    """
    Wraps a trained GLADIUS kernel as a frozen cognitive adapter.
    
    The Plug loads a GLADIUS checkpoint, freezes it, and exposes its
    transformer layer stack through a learned membrane. External models
    produce hidden states → membrane projects to GLADIUS dim → layers
    process with depth cache and attention → PUP reads uncertainty.
    
    Usage:
        plug = GladiusPlug("checkpoint.pt", external_dim=768)
        enriched, pup_manifold = plug(gpt2_hidden_states)
        
        # Only membrane trains
        optimizer = torch.optim.Adam(plug.membrane_params(), lr=1e-4)
    """
    
    def __init__(
        self,
        checkpoint_path: str,
        external_dim: int,
        freeze_gladius: bool = True,
        device: str = 'cpu',
    ):
        super().__init__()
        
        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        
        # Load checkpoint
        ckpt = torch.load(str(checkpoint_path), map_location=device, weights_only=False)
        
        # Extract config; handle both dataclass and dict forms
        config_raw = ckpt.get('config')
        if config_raw is None:
            raise ValueError("Checkpoint missing 'config' key")
        
        if dataclasses.is_dataclass(config_raw) and not isinstance(config_raw, type):
            config_dict = dataclasses.asdict(config_raw)
        elif isinstance(config_raw, dict):
            config_dict = config_raw
        else:
            config_dict = dict(config_raw)
        
        # Build kernel from source (handles import path resolution)
        kernel_src = Path(__file__).parent.parent
        gladius_src = self._find_kernel_source(kernel_src)
        
        import sys
        if str(gladius_src) not in sys.path:
            sys.path.insert(0, str(gladius_src))
        
        from kernel import GladiusKernel
        from kernel.config import KernelConfig
        
        # Filter config to valid KernelConfig fields
        valid_fields = {f.name for f in dataclasses.fields(KernelConfig)}
        filtered = {k: v for k, v in config_dict.items() if k in valid_fields}
        
        # Handle dtype serialization
        if 'dtype' in filtered:
            dtype_val = filtered['dtype']
            if isinstance(dtype_val, str):
                filtered['dtype'] = getattr(torch, dtype_val.replace('torch.', ''), torch.float32)
            elif not isinstance(dtype_val, torch.dtype):
                filtered['dtype'] = torch.float32
        
        # Ensure cold_embedding_dim matches hidden_dim
        if 'cold_embedding_dim' not in filtered or filtered.get('cold_embedding_dim') != filtered.get('hidden_dim'):
            filtered['cold_embedding_dim'] = filtered.get('hidden_dim', 640)
        
        config = KernelConfig(**filtered)
        self.kernel = GladiusKernel(config)
        
        # Load model weights (strict=False for optional components)
        state_dict = ckpt.get('model_state_dict', ckpt.get('state_dict', {}))
        self.kernel.load_state_dict(state_dict, strict=False)
        
        # Apply synthase upgrade if checkpoint indicates it
        self._has_synthase = bool(ckpt.get('synthase', False))
        if self._has_synthase:
            try:
                from synthase.synthase_surgery import upgrade_to_synthase
                upgrade_to_synthase(self.kernel)
                # Reload weights to pick up synthase parameters
                self.kernel.load_state_dict(state_dict, strict=False)
            except ImportError:
                print("Warning: Checkpoint has synthase but synthase_surgery not found. Skipping.")
                self._has_synthase = False
        
        # Apply PUP if checkpoint indicates it
        self._has_pup = bool(ckpt.get('pup', False))
        self.pup_head = None
        if self._has_pup:
            try:
                from pup.pup_surgery import upgrade_kernel_to_pup
                upgrade_kernel_to_pup(self.kernel)
                # If the head did not exist before the upgrade, the strict=False
                # load above silently skipped pup_head.*, so reload the weights.
                self.kernel.load_state_dict(state_dict, strict=False)
                self.pup_head = self.kernel.pup_head
            except ImportError:
                print("Warning: Checkpoint has PUP but pup_surgery not found. Skipping.")
                self._has_pup = False
        
        # Freeze GLADIUS kernel (the whole point)
        if freeze_gladius:
            for p in self.kernel.parameters():
                p.requires_grad = False
            self.kernel.eval()
        
        # Extract dimensions from loaded config
        self.gladius_dim = config.hidden_dim
        self.num_layers = config.num_layers
        self.max_seq_len = config.max_seq_len
        self.config = config
        self._step = ckpt.get('step', 0)
        self._frozen = freeze_gladius
        
        # Create membrane: the ONLY learned component
        self.membrane = Membrane(external_dim, self.gladius_dim)
        
        # Move to device
        self.to(device)
        
        self._report()
    
    def _find_kernel_source(self, start: Path) -> str:
        """
        Find the GLADIUS kernel source directory.
        Searches upward from plug/ for a directory containing kernel/kernel.py.
        Falls back to gladius_v2/src/ if available.
        """
        # Check if we're inside gladius_v2/staging/kernel/plug/
        # -> parent = plug, parent.parent = kernel, parent.parent.parent = staging
        #    but the actual kernel.py with GladiusKernel is in gladius_v2/src/
        
        # Strategy: walk up looking for src/kernel/kernel.py
        current = start
        for _ in range(6):
            candidate = current / 'src'
            if (candidate / 'kernel' / 'kernel.py').exists():
                return str(candidate)
            current = current.parent
        
        # Fallback: check gladius_v2 relative to workspace
        workspace = Path(os.environ.get('GLADIUS_WORKSPACE', '.'))
        gladius_src = workspace / 'gladius_v2' / 'src'
        if (gladius_src / 'kernel' / 'kernel.py').exists():
            return str(gladius_src)
        
        raise ImportError(
            "Cannot find GLADIUS kernel source (kernel/kernel.py). "
            "Expected in gladius_v2/src/ or parent directories of plug/."
        )
    
    def forward(
        self,
        external_hidden_states: torch.Tensor,
        return_pup: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """
        Project external representations through the GLADIUS cognitive stack.
        
        Args:
            external_hidden_states: [batch, seq_len, external_dim] 
                Hidden states from any external model (GPT-2, Qwen, VLM, etc.)
            return_pup: whether to compute PUP uncertainty manifold
        
        Returns:
            enriched: [batch, seq_len, gladius_dim], depth-enriched representations
            pup_manifold: dict with mu, sigma, confidence, log_var (or None)
        """
        B, S, _ = external_hidden_states.shape
        
        # Truncate to GLADIUS max sequence length
        if S > self.max_seq_len:
            external_hidden_states = external_hidden_states[:, :self.max_seq_len, :]
            S = self.max_seq_len
        
        # Project through membrane (external_dim → gladius_dim)
        x = self.membrane(external_hidden_states)
        
        # Run through GLADIUS transformer layers (bypassing embedding)
        enriched = self._forward_through_layers(x)
        
        # PUP uncertainty manifold
        pup_manifold = None
        if return_pup and self.pup_head is not None:
            pup_manifold = self.pup_head(hidden=enriched)
        
        return enriched, pup_manifold
    
    def _forward_through_layers(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run through GLADIUS transformer layer stack, bypassing token embedding.
        
        Handles both standard and synthase-upgraded layers.
        Builds causal mask matching the kernel's expected format.
        """
        B, S, D = x.shape
        
        # Build causal mask (same format as GladiusKernel.forward)
        if S <= self.max_seq_len and hasattr(self.kernel, 'causal_mask'):
            mask = self.kernel.causal_mask[:, :, :S, :S]
        else:
            mask = torch.tril(torch.ones(1, 1, S, S, device=x.device))
        
        # Run through each transformer layer
        for layer in self.kernel.layers:
            x = layer(x, mask=mask)
        
        # Final norm
        if hasattr(self.kernel, 'final_norm'):
            x = self.kernel.final_norm(x)
        
        return x
    
    def membrane_params(self):
        """Return only membrane parameters (for optimizer)."""
        return self.membrane.parameters()
    
    def membrane_param_count(self) -> int:
        """Count of trainable membrane parameters."""
        return sum(p.numel() for p in self.membrane.parameters())
    
    def kernel_param_count(self) -> int:
        """Count of frozen kernel parameters."""
        return sum(p.numel() for p in self.kernel.parameters())
    
    def save_membrane(self, path: str):
        """Save only the membrane weights (tiny file)."""
        torch.save({
            'membrane_state_dict': self.membrane.state_dict(),
            'external_dim': self.membrane.external_dim,
            'gladius_dim': self.membrane.gladius_dim,
            'kernel_step': self._step,
        }, path)
        print(f"Membrane saved: {path} ({self.membrane_param_count():,} params)")
    
    def load_membrane(self, path: str):
        """Load membrane weights from file."""
        data = torch.load(path, map_location='cpu')
        state = data.get('membrane_state_dict', data)
        self.membrane.load_state_dict(state)
        print(f"Membrane loaded: {path}")
    
    def _report(self):
        """Print Plug configuration summary."""
        membrane_p = self.membrane_param_count()
        kernel_p = self.kernel_param_count()
        total_p = membrane_p + kernel_p
        
        print(f"\n{'='*55}")
        print(f"  GLADIUS PLUG: Cognitive Adapter")
        print(f"{'='*55}")
        print(f"  Kernel:       {kernel_p:>12,} params (frozen={self._frozen})")
        print(f"  Membrane:     {membrane_p:>12,} params (TRAINABLE)")
        print(f"  Total:        {total_p:>12,} params")
        print(f"  Overhead:     {membrane_p/kernel_p*100:.3f}%")
        print(f"  External dim: {self.membrane.external_dim}")
        print(f"  GLADIUS dim:  {self.gladius_dim}")
        print(f"  Layers:       {self.num_layers}")
        print(f"  Synthase:     {'yes' if self._has_synthase else 'no'}")
        print(f"  PUP:          {'yes' if self._has_pup else 'no'}")
        print(f"  From step:    {self._step:,}")
        print(f"{'='*55}\n")
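

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module. It shows the training
# pattern described in the GladiusPlug docstring, where only the membrane
# receives gradient updates. A plain nn.Linear stands in for the frozen
# GLADIUS layer stack so the sketch runs without a checkpoint. The function
# name, the stand-in stack, the tensor sizes, and the toy loss are all
# assumptions made for illustration, not the authors' training recipe.
# ---------------------------------------------------------------------------
def _membrane_only_training_sketch(external_dim: int = 768, gladius_dim: int = 640):
    """Run one optimizer step and report which parameters changed."""
    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    # Trainable projection, mirroring Membrane: Linear followed by LayerNorm.
    membrane = nn.Sequential(
        nn.Linear(external_dim, gladius_dim),
        nn.LayerNorm(gladius_dim),
    )

    # Stand-in for the frozen kernel (an assumption, not the real layer stack).
    frozen_stack = nn.Linear(gladius_dim, gladius_dim)
    for p in frozen_stack.parameters():
        p.requires_grad = False
    frozen_stack.eval()

    # Snapshot weights so the effect of the optimizer step can be checked.
    frozen_before = [p.detach().clone() for p in frozen_stack.parameters()]
    membrane_before = [p.detach().clone() for p in membrane.parameters()]

    # The optimizer only ever sees membrane parameters.
    optimizer = torch.optim.Adam(membrane.parameters(), lr=1e-4)

    hidden = torch.randn(2, 16, external_dim)   # [batch, seq_len, external_dim]
    enriched = frozen_stack(membrane(hidden))   # [batch, seq_len, gladius_dim]
    loss = enriched.pow(2).mean()               # toy objective for the sketch

    optimizer.zero_grad()
    loss.backward()  # gradients pass through the frozen stack into the membrane
    optimizer.step()

    frozen_unchanged = all(
        torch.equal(a, b) for a, b in zip(frozen_before, frozen_stack.parameters())
    )
    membrane_changed = any(
        not torch.equal(a, b) for a, b in zip(membrane_before, membrane.parameters())
    )
    return tuple(enriched.shape), frozen_unchanged, membrane_changed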