Gladius / extensions /plug /test_plug.py
amuzetnoM's picture
GLADIUS v5.0 — Cognitive kernel with Synthase depth attention, PUP uncertainty, Memory V2, multi-tokenizer architecture
3f42614
"""
GLADIUS Plug — Test Suite
Verifies membrane projection, kernel freeze, gradient flow, and save/load.
Runs WITHOUT a real GLADIUS checkpoint by mocking the kernel.
"""
import torch
import torch.nn as nn
import tempfile
import os
import sys
from pathlib import Path
# Import directly from the module file to avoid package __init__.py conflicts
plug_dir = str(Path(__file__).parent)
sys.path.insert(0, str(Path(__file__).parent.parent))
from plug.plug import Membrane
def test_membrane_shape():
    """Check that a (2, 128, 768) input is projected to (2, 128, 640)."""
    projector = Membrane(external_dim=768, gladius_dim=640)
    projected = projector(torch.randn(2, 128, 768))

    # Only the last dimension is expected to change (768 -> 640).
    wanted_shape = (2, 128, 640)
    assert projected.shape == wanted_shape, f"Expected (2, 128, 640), got {projected.shape}"
    print("[PASS] test_membrane_shape: (2, 128, 768) -> (2, 128, 640)")
def test_membrane_different_dims():
    """Run the membrane over several (external_dim, gladius_dim) pairs.

    Each pair gets a freshly constructed Membrane and a (1, 32, external_dim)
    random input; the output must be (1, 32, gladius_dim).
    """
    dim_pairs = ((768, 640), (2048, 640), (256, 640), (4096, 256))
    for ext_dim, gladius_dim in dim_pairs:
        projector = Membrane(external_dim=ext_dim, gladius_dim=gladius_dim)
        result = projector(torch.randn(1, 32, ext_dim))
        wanted_shape = (1, 32, gladius_dim)
        assert result.shape == wanted_shape, (
            f"ext={ext_dim}->gladius={gladius_dim}: expected (1,32,{gladius_dim}), got {result.shape}"
        )
    print("[PASS] test_membrane_different_dims: all dimension pairs work")
def test_membrane_gradients():
    """Backpropagate a summed output and check every parameter's gradient.

    Each membrane parameter must end up with a gradient tensor whose
    absolute sum is strictly positive.
    """
    projector = Membrane(external_dim=768, gladius_dim=640)
    sample = torch.randn(2, 128, 768)

    # Scalar loss so that backward() needs no explicit gradient argument.
    projector(sample).sum().backward()

    for pname, tensor in projector.named_parameters():
        grad = tensor.grad
        assert grad is not None, f"Membrane param '{pname}' has no gradient"
        assert grad.abs().sum() > 0, f"Membrane param '{pname}' has zero gradient"
    print("[PASS] test_membrane_gradients: all membrane params receive gradients")
def test_freeze_kernel_simulation():
    """Check gradient routing when a frozen stand-in kernel follows the membrane.

    The stand-in kernel has all parameters marked as not requiring grad.
    After a backward pass through membrane -> kernel, no kernel parameter
    may hold a gradient while every membrane parameter must hold one.
    """

    class MockKernel(nn.Module):
        """Stand-in kernel: three stacked Linear(640, 640) layers plus a LayerNorm."""

        def __init__(self):
            super().__init__()
            stack = []
            for _ in range(3):
                stack.append(nn.Linear(640, 640))
            self.layers = nn.ModuleList(stack)
            self.final_norm = nn.LayerNorm(640)

        def forward(self, x):
            hidden = x
            for block in self.layers:
                hidden = block(hidden)
            return self.final_norm(hidden)

    kernel = MockKernel()
    projector = Membrane(external_dim=768, gladius_dim=640)

    # Freeze every kernel parameter in one call.
    kernel.requires_grad_(False)

    # Membrane feeds the frozen kernel; the summed output is the loss.
    sample = torch.randn(2, 128, 768)
    kernel(projector(sample)).sum().backward()

    # Frozen side: nothing may have accumulated a gradient.
    for pname, tensor in kernel.named_parameters():
        assert tensor.grad is None, f"Kernel param '{pname}' should have no gradient (frozen)"

    # Trainable side: every parameter must have a gradient.
    for pname, tensor in projector.named_parameters():
        assert tensor.grad is not None, f"Membrane param '{pname}' should have gradient"

    print("[PASS] test_freeze_kernel_simulation: kernel frozen, membrane learns")
def test_membrane_save_load():
    """Round-trip a membrane's state dict through torch.save / torch.load.

    A perturbed membrane is saved to a temporary ``.pt`` file together with
    its ``external_dim`` and ``gladius_dim`` attributes, loaded into a second
    membrane, and both are compared parameter-by-parameter and on a shared
    random input. The temporary file is removed whether or not the checks
    pass.
    """
    source = Membrane(external_dim=768, gladius_dim=640)

    # One forward/backward pass. This only populates .grad; the weights are
    # changed by the explicit perturbation below.
    warmup = torch.randn(2, 64, 768)
    source(warmup).sum().backward()

    # Nudge every parameter in place so the saved values are not the
    # freshly initialised ones.
    with torch.no_grad():
        for weight in source.parameters():
            weight.add_(torch.randn_like(weight) * 0.1)

    # Reserve a file name; the handle is closed on leaving the with-block
    # and the file is deleted manually in the finally clause.
    with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as handle:
        save_path = handle.name

    try:
        checkpoint = {
            'membrane_state_dict': source.state_dict(),
            'external_dim': source.external_dim,
            'gladius_dim': source.gladius_dim,
        }
        torch.save(checkpoint, save_path)

        # Fresh membrane that receives the stored weights.
        restored = Membrane(external_dim=768, gladius_dim=640)
        loaded = torch.load(save_path, map_location='cpu')
        restored.load_state_dict(loaded['membrane_state_dict'])

        # Parameters must be bit-for-bit equal.
        param_pairs = zip(source.named_parameters(), restored.named_parameters())
        for (src_name, src_param), (_, dst_param) in param_pairs:
            assert torch.equal(src_param, dst_param), f"Mismatch on '{src_name}' after load"

        # Same input through both modules in eval mode, without autograd.
        probe = torch.randn(1, 32, 768)
        source.eval()
        restored.eval()
        with torch.no_grad():
            out_source = source(probe)
            out_restored = restored(probe)
        assert torch.allclose(out_source, out_restored, atol=1e-6), "Output mismatch after load"

        print("[PASS] test_membrane_save_load: roundtrip preserves weights and outputs")
    finally:
        os.unlink(save_path)
def test_membrane_param_count():
    """Compare the membrane's parameter total with a hand-computed figure.

    The expected figure is a 768 -> 640 linear layer (weight + bias) plus a
    640-wide LayerNorm (weight + bias).
    """
    ext_dim, g_dim = 768, 640
    membrane = Membrane(external_dim=ext_dim, gladius_dim=g_dim)

    expected_linear = ext_dim * g_dim + g_dim  # weight matrix + bias vector
    expected_ln = g_dim + g_dim  # scale + shift
    expected_total = expected_linear + expected_ln

    actual = 0
    for tensor in membrane.parameters():
        actual += tensor.numel()

    assert actual == expected_total, f"Expected {expected_total:,} params, got {actual:,}"
    print(f"[PASS] test_membrane_param_count: {actual:,} params (Linear: {expected_linear:,} + LN: {expected_ln:,})")
def test_membrane_layernorm_output():
    """Check that membrane output is roughly zero-mean / unit-variance.

    A large-scale input (standard normal times 100) is projected, then the
    mean and biased variance over the last dimension are computed. The
    worst-case absolute mean and the worst-case deviation of the variance
    from 1 must both be below 0.5.

    NOTE(review): this presumes Membrane ends in a LayerNorm whose affine
    parameters are still at their initial values; Membrane is defined in
    plug.plug, which is not visible here.
    """
    membrane = Membrane(external_dim=768, gladius_dim=640)
    x = torch.randn(2, 128, 768) * 100  # large-scale input
    out = membrane(x)

    mean = out.mean(dim=-1)
    var = out.var(dim=-1, unbiased=False)

    # Report exactly the quantity each assertion tests, so a failure message
    # cannot show a healthy-looking number (e.g. a mean variance near 1)
    # while a single position is badly off.
    worst_mean = mean.abs().max().item()
    worst_var_dev = (var - 1.0).abs().max().item()

    assert worst_mean < 0.5, f"Post-LN mean too high: {worst_mean:.4f}"
    assert worst_var_dev < 0.5, (
        f"Post-LN var far from 1: max |var - 1| = {worst_var_dev:.4f}"
    )
    print("[PASS] test_membrane_layernorm_output: output is approximately normalized")
def test_plug_forward_simulation():
    """Run a Plug-style pipeline end to end: membrane -> frozen stack -> output.

    A mock residual layer stack stands in for the real kernel so no
    checkpoint is needed. The test checks the output shape and that, after
    backward, only membrane parameters carry gradients.
    """

    class MockGladiusStack(nn.Module):
        """Mock layer stack: residual LayerNorm+Linear blocks, then a final norm."""

        def __init__(self, dim=640, num_layers=14):
            super().__init__()
            blocks = []
            for _ in range(num_layers):
                blocks.append(nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)))
            self.layers = nn.ModuleList(blocks)
            self.final_norm = nn.LayerNorm(dim)

        def forward(self, x):
            hidden = x
            for block in self.layers:
                # Residual connection around each block.
                hidden = hidden + block(hidden)
            return self.final_norm(hidden)

    # Assemble the Plug-like pair.
    membrane = Membrane(external_dim=768, gladius_dim=640)
    kernel = MockGladiusStack(dim=640, num_layers=14)

    # Freeze the whole stack.
    kernel.requires_grad_(False)

    # Forward pass.
    external_hidden = torch.randn(2, 128, 768)
    enriched = kernel(membrane(external_hidden))
    assert enriched.shape == (2, 128, 640)

    # Backward pass - only the membrane should receive gradients.
    enriched.sum().backward()

    membrane_grads = len([p for p in membrane.parameters() if p.grad is not None])
    kernel_grads = len([p for p in kernel.parameters() if p.grad is not None])
    assert membrane_grads > 0, "Membrane should have gradients"
    assert kernel_grads == 0, "Kernel should have zero gradients (frozen)"
    print("[PASS] test_plug_forward_simulation: full pipeline, correct gradient flow")
def test_membrane_batch_sizes():
    """Feed one membrane inputs with leading dimension 1, 2, 4 and 16.

    The same Membrane instance is reused; each output must keep the leading
    dimension and the 64-long middle dimension, with a last dimension of 640.
    """
    projector = Membrane(external_dim=768, gladius_dim=640)
    for batch in (1, 2, 4, 16):
        result = projector(torch.randn(batch, 64, 768))
        wanted_shape = (batch, 64, 640)
        assert result.shape == wanted_shape, f"Batch {batch}: wrong shape {result.shape}"
    print("[PASS] test_membrane_batch_sizes: handles batch 1, 2, 4, 16")
# ─── Run all tests ───
if __name__ == '__main__':
    # Script entry point: run every test in order, print a summary, and
    # exit with a non-zero status if anything failed so that shells and CI
    # jobs can detect the failure.
    tests = [
        test_membrane_shape,
        test_membrane_different_dims,
        test_membrane_gradients,
        test_freeze_kernel_simulation,
        test_membrane_save_load,
        test_membrane_param_count,
        test_membrane_layernorm_output,
        test_plug_forward_simulation,
        test_membrane_batch_sizes,
    ]
    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            # Top-level boundary: report the failure (AssertionError
            # included) and keep going so the remaining tests still run.
            print(f"[FAIL] {test.__name__}: {e}")
            failed += 1
        else:
            passed += 1
    print(f"\n{'='*50}")
    print(f" {passed}/{passed+failed} tests passed")
    if failed > 0:
        print(f" {failed} FAILED")
    else:
        print(f" ALL PASS βœ“")
    print(f"{'='*50}")
    if failed > 0:
        # Previously the script always exited 0, hiding failures from callers.
        sys.exit(1)