""" GLADIUS Plug — Test Suite Verifies membrane projection, kernel freeze, gradient flow, and save/load. Runs WITHOUT a real GLADIUS checkpoint by mocking the kernel. """ import torch import torch.nn as nn import tempfile import os import sys from pathlib import Path # Import directly from the module file to avoid package __init__.py conflicts plug_dir = str(Path(__file__).parent) sys.path.insert(0, str(Path(__file__).parent.parent)) from plug.plug import Membrane def test_membrane_shape(): """Membrane projects external_dim → gladius_dim with correct shapes.""" membrane = Membrane(external_dim=768, gladius_dim=640) x = torch.randn(2, 128, 768) out = membrane(x) assert out.shape == (2, 128, 640), f"Expected (2, 128, 640), got {out.shape}" print("[PASS] test_membrane_shape: (2, 128, 768) -> (2, 128, 640)") def test_membrane_different_dims(): """Membrane works with various external dimensions.""" for ext_dim, gladius_dim in [(768, 640), (2048, 640), (256, 640), (4096, 256)]: membrane = Membrane(external_dim=ext_dim, gladius_dim=gladius_dim) x = torch.randn(1, 32, ext_dim) out = membrane(x) assert out.shape == (1, 32, gladius_dim), \ f"ext={ext_dim}->gladius={gladius_dim}: expected (1,32,{gladius_dim}), got {out.shape}" print("[PASS] test_membrane_different_dims: all dimension pairs work") def test_membrane_gradients(): """Membrane parameters have gradients after backward pass.""" membrane = Membrane(external_dim=768, gladius_dim=640) x = torch.randn(2, 128, 768) out = membrane(x) loss = out.sum() loss.backward() for name, param in membrane.named_parameters(): assert param.grad is not None, f"Membrane param '{name}' has no gradient" assert param.grad.abs().sum() > 0, f"Membrane param '{name}' has zero gradient" print("[PASS] test_membrane_gradients: all membrane params receive gradients") def test_freeze_kernel_simulation(): """ Simulate the freeze behavior: kernel params frozen, membrane params trainable. Uses a simple nn.Module as kernel stand-in. """ # Mock kernel: 3-layer transformer class MockKernel(nn.Module): def __init__(self): super().__init__() self.layers = nn.ModuleList([nn.Linear(640, 640) for _ in range(3)]) self.final_norm = nn.LayerNorm(640) def forward(self, x): for layer in self.layers: x = layer(x) return self.final_norm(x) kernel = MockKernel() membrane = Membrane(external_dim=768, gladius_dim=640) # Freeze kernel for p in kernel.parameters(): p.requires_grad = False # Forward through membrane + kernel x = torch.randn(2, 128, 768) projected = membrane(x) enriched = kernel(projected) loss = enriched.sum() loss.backward() # Kernel params: frozen (no grad) for name, param in kernel.named_parameters(): assert param.grad is None, f"Kernel param '{name}' should have no gradient (frozen)" # Membrane params: trainable (has grad) for name, param in membrane.named_parameters(): assert param.grad is not None, f"Membrane param '{name}' should have gradient" print("[PASS] test_freeze_kernel_simulation: kernel frozen, membrane learns") def test_membrane_save_load(): """Membrane state roundtrips through save/load correctly.""" membrane = Membrane(external_dim=768, gladius_dim=640) # Run a forward + backward to give params non-trivial values x = torch.randn(2, 64, 768) out = membrane(x) loss = out.sum() loss.backward() # Manually update weights to make them non-default with torch.no_grad(): for p in membrane.parameters(): p.add_(torch.randn_like(p) * 0.1) # Save with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f: save_path = f.name try: torch.save({ 'membrane_state_dict': membrane.state_dict(), 'external_dim': membrane.external_dim, 'gladius_dim': membrane.gladius_dim, }, save_path) # Load into new membrane membrane2 = Membrane(external_dim=768, gladius_dim=640) data = torch.load(save_path, map_location='cpu') membrane2.load_state_dict(data['membrane_state_dict']) # Verify weights match for (n1, p1), (n2, p2) in zip(membrane.named_parameters(), membrane2.named_parameters()): assert torch.equal(p1, p2), f"Mismatch on '{n1}' after load" # Verify outputs match x_test = torch.randn(1, 32, 768) membrane.eval() membrane2.eval() with torch.no_grad(): out1 = membrane(x_test) out2 = membrane2(x_test) assert torch.allclose(out1, out2, atol=1e-6), "Output mismatch after load" print("[PASS] test_membrane_save_load: roundtrip preserves weights and outputs") finally: os.unlink(save_path) def test_membrane_param_count(): """Verify parameter count math: Linear(ext, gladius) + LayerNorm(gladius).""" membrane = Membrane(external_dim=768, gladius_dim=640) expected_linear = 768 * 640 + 640 # weight + bias expected_ln = 640 + 640 # weight + bias expected_total = expected_linear + expected_ln actual = sum(p.numel() for p in membrane.parameters()) assert actual == expected_total, \ f"Expected {expected_total:,} params, got {actual:,}" print(f"[PASS] test_membrane_param_count: {actual:,} params (Linear: {expected_linear:,} + LN: {expected_ln:,})") def test_membrane_layernorm_output(): """LayerNorm in membrane normalizes the projected output.""" membrane = Membrane(external_dim=768, gladius_dim=640) x = torch.randn(2, 128, 768) * 100 # Large scale input out = membrane(x) # After LayerNorm, last dim should be approximately zero-mean, unit-var mean = out.mean(dim=-1) var = out.var(dim=-1, unbiased=False) assert mean.abs().max() < 0.5, f"Post-LN mean too high: {mean.abs().max():.4f}" assert (var - 1.0).abs().max() < 0.5, f"Post-LN var far from 1: {var.mean():.4f}" print("[PASS] test_membrane_layernorm_output: output is approximately normalized") def test_plug_forward_simulation(): """ Full Plug-style forward: membrane → frozen layers → output. Simulates what GladiusPlug.forward() does without needing a checkpoint. """ class MockGladiusStack(nn.Module): """Simulates GLADIUS layer stack + final norm.""" def __init__(self, dim=640, num_layers=14): super().__init__() self.layers = nn.ModuleList([ nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim)) for _ in range(num_layers) ]) self.final_norm = nn.LayerNorm(dim) def forward(self, x): for layer in self.layers: x = x + layer(x) # residual return self.final_norm(x) # Build Plug-like setup membrane = Membrane(external_dim=768, gladius_dim=640) kernel = MockGladiusStack(dim=640, num_layers=14) # Freeze kernel for p in kernel.parameters(): p.requires_grad = False # Forward external_hidden = torch.randn(2, 128, 768) projected = membrane(external_hidden) enriched = kernel(projected) assert enriched.shape == (2, 128, 640) # Backward — only membrane should get gradients loss = enriched.sum() loss.backward() membrane_grads = sum(1 for p in membrane.parameters() if p.grad is not None) kernel_grads = sum(1 for p in kernel.parameters() if p.grad is not None) assert membrane_grads > 0, "Membrane should have gradients" assert kernel_grads == 0, "Kernel should have zero gradients (frozen)" print("[PASS] test_plug_forward_simulation: full pipeline, correct gradient flow") def test_membrane_batch_sizes(): """Membrane handles various batch sizes including batch=1.""" membrane = Membrane(external_dim=768, gladius_dim=640) for batch in [1, 2, 4, 16]: x = torch.randn(batch, 64, 768) out = membrane(x) assert out.shape == (batch, 64, 640), f"Batch {batch}: wrong shape {out.shape}" print("[PASS] test_membrane_batch_sizes: handles batch 1, 2, 4, 16") # ─── Run all tests ─── if __name__ == '__main__': tests = [ test_membrane_shape, test_membrane_different_dims, test_membrane_gradients, test_freeze_kernel_simulation, test_membrane_save_load, test_membrane_param_count, test_membrane_layernorm_output, test_plug_forward_simulation, test_membrane_batch_sizes, ] passed = 0 failed = 0 for test in tests: try: test() passed += 1 except Exception as e: print(f"[FAIL] {test.__name__}: {e}") failed += 1 print(f"\n{'='*50}") print(f" {passed}/{passed+failed} tests passed") if failed > 0: print(f" {failed} FAILED") else: print(f" ALL PASS ✓") print(f"{'='*50}")