"""
GLADIUS Plug — Test Suite

Verifies membrane projection, kernel freeze, gradient flow, and save/load.
Runs WITHOUT a real GLADIUS checkpoint by mocking the kernel.
"""
|
|
| import torch |
| import torch.nn as nn |
| import tempfile |
| import os |
| import sys |
| from pathlib import Path |
|
|
| |
# Make the directory above this test file importable so `plug.plug` resolves
# when the suite is run directly as a script rather than from an installed
# package. `resolve()` makes this independent of how the script was invoked:
# on Python < 3.9 `__file__` can be relative (e.g. "x.py"), and
# Path("x.py").parent.parent is just "." rather than the real parent directory.
_HERE = Path(__file__).resolve().parent
plug_dir = str(_HERE)  # kept for compatibility; currently unused in this file
sys.path.insert(0, str(_HERE.parent))
|
|
| from plug.plug import Membrane |
|
|
|
|
def test_membrane_shape():
    """A 768-wide input is mapped to a 640-wide output; batch/seq dims are untouched."""
    proj = Membrane(external_dim=768, gladius_dim=640)
    result = proj(torch.randn(2, 128, 768))

    want = (2, 128, 640)
    assert result.shape == want, f"Expected (2, 128, 640), got {result.shape}"
    print("[PASS] test_membrane_shape: (2, 128, 768) -> (2, 128, 640)")
|
|
|
|
def test_membrane_different_dims():
    """Projection works across several (external_dim, gladius_dim) pairs."""
    dim_pairs = ((768, 640), (2048, 640), (256, 640), (4096, 256))
    for src_dim, dst_dim in dim_pairs:
        proj = Membrane(external_dim=src_dim, gladius_dim=dst_dim)
        result = proj(torch.randn(1, 32, src_dim))
        assert result.shape == (1, 32, dst_dim), (
            f"ext={src_dim}->gladius={dst_dim}: expected (1,32,{dst_dim}), got {result.shape}"
        )
    print("[PASS] test_membrane_different_dims: all dimension pairs work")
|
|
|
|
def test_membrane_gradients():
    """Every membrane parameter receives a non-zero gradient after backward."""
    membrane = Membrane(external_dim=768, gladius_dim=640)

    x = torch.randn(2, 128, 768)
    out = membrane(x)
    # Membrane's output is LayerNorm-normalized (test_membrane_layernorm_output
    # checks this). With a uniform LN weight, as at default init, `out.sum()` is
    # invariant to the input, so the projection's gradient would be zero up to
    # float rounding and the non-zero check below would only pass on noise.
    # Weighting by a random tensor makes the loss depend on every parameter.
    loss = (out * torch.randn_like(out)).sum()
    loss.backward()

    for name, param in membrane.named_parameters():
        assert param.grad is not None, f"Membrane param '{name}' has no gradient"
        assert param.grad.abs().sum() > 0, f"Membrane param '{name}' has zero gradient"
    print("[PASS] test_membrane_gradients: all membrane params receive gradients")
|
|
|
|
def test_freeze_kernel_simulation():
    """
    Simulate the freeze behavior: kernel params frozen, membrane params trainable.

    Uses a small nn.Module (3 Linear layers + final LayerNorm) as a kernel
    stand-in. Gradients must reach the membrane *through* the frozen kernel,
    while the kernel's own parameters receive none.
    """

    class MockKernel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(640, 640) for _ in range(3)])
            self.final_norm = nn.LayerNorm(640)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.final_norm(x)

    kernel = MockKernel()
    membrane = Membrane(external_dim=768, gladius_dim=640)

    # Freeze the kernel: autograd will not populate .grad on these tensors.
    for p in kernel.parameters():
        p.requires_grad = False

    x = torch.randn(2, 128, 768)
    projected = membrane(x)
    enriched = kernel(projected)
    # The kernel ends in a freshly initialized LayerNorm (uniform weight), so
    # `enriched.sum()` is invariant to its input and its gradient w.r.t. the
    # membrane is zero up to float rounding. A random weighting gives a loss
    # that genuinely depends on the membrane's parameters.
    loss = (enriched * torch.randn_like(enriched)).sum()
    loss.backward()

    # Frozen kernel params must not accumulate gradients.
    for name, param in kernel.named_parameters():
        assert param.grad is None, f"Kernel param '{name}' should have no gradient (frozen)"

    # Membrane params must receive a real (non-zero) gradient through the kernel.
    for name, param in membrane.named_parameters():
        assert param.grad is not None, f"Membrane param '{name}' should have gradient"
        assert param.grad.abs().sum() > 0, f"Membrane param '{name}' has zero gradient"

    print("[PASS] test_freeze_kernel_simulation: kernel frozen, membrane learns")
|
|
|
|
def test_membrane_save_load():
    """Membrane state and its dimension metadata round-trip through save/load."""
    membrane = Membrane(external_dim=768, gladius_dim=640)

    # Perturb the weights in place to stand in for training updates. (No
    # backward pass is needed: .grad is not part of state_dict.)
    with torch.no_grad():
        for p in membrane.parameters():
            p.add_(torch.randn_like(p) * 0.1)

    # Reserve a path, then close the handle so torch.save can reopen it
    # (required on Windows); the file is removed in `finally`.
    with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
        save_path = f.name

    try:
        torch.save({
            'membrane_state_dict': membrane.state_dict(),
            'external_dim': membrane.external_dim,
            'gladius_dim': membrane.gladius_dim,
        }, save_path)

        data = torch.load(save_path, map_location='cpu')

        # The saved metadata must round-trip, and is used to rebuild the module
        # exactly as a real loader would.
        assert data['external_dim'] == 768, f"external_dim not preserved: {data['external_dim']}"
        assert data['gladius_dim'] == 640, f"gladius_dim not preserved: {data['gladius_dim']}"
        membrane2 = Membrane(external_dim=data['external_dim'], gladius_dim=data['gladius_dim'])
        membrane2.load_state_dict(data['membrane_state_dict'])

        # Compare full state dicts (parameters AND buffers) key for key, rather
        # than zipping parameter lists, which silently truncates on a length mismatch.
        sd1, sd2 = membrane.state_dict(), membrane2.state_dict()
        assert sd1.keys() == sd2.keys(), f"State dict keys differ: {sorted(sd1)} vs {sorted(sd2)}"
        for key, t1 in sd1.items():
            assert torch.equal(t1, sd2[key]), f"Mismatch on '{key}' after load"

        # Identical weights must also give identical outputs.
        x_test = torch.randn(1, 32, 768)
        membrane.eval()
        membrane2.eval()
        with torch.no_grad():
            out1 = membrane(x_test)
            out2 = membrane2(x_test)
        assert torch.allclose(out1, out2, atol=1e-6), "Output mismatch after load"

        print("[PASS] test_membrane_save_load: roundtrip preserves weights and outputs")
    finally:
        os.unlink(save_path)
|
|
|
|
def test_membrane_param_count():
    """Parameter total equals Linear(ext -> gladius, with bias) plus LayerNorm(gladius).

    Assumes the LayerNorm carries an elementwise-affine weight and bias.
    """
    ext, gdim = 768, 640
    membrane = Membrane(external_dim=ext, gladius_dim=gdim)

    expected_linear = ext * gdim + gdim   # weight matrix + bias vector
    expected_ln = 2 * gdim                # LayerNorm weight + bias
    expected_total = expected_linear + expected_ln

    actual = 0
    for p in membrane.parameters():
        actual += p.numel()

    assert actual == expected_total, f"Expected {expected_total:,} params, got {actual:,}"
    print(
        f"[PASS] test_membrane_param_count: {actual:,} params "
        f"(Linear: {expected_linear:,} + LN: {expected_ln:,})"
    )
|
|
|
|
def test_membrane_layernorm_output():
    """
    LayerNorm in membrane normalizes the projected output.

    Inputs are scaled by 100 so an un-normalized projection would be far from
    zero-mean / unit-variance. The tolerances assume the LayerNorm's affine
    params are near their default init (weight=1, bias=0), as in a fresh module.
    """
    membrane = Membrane(external_dim=768, gladius_dim=640)

    x = torch.randn(2, 128, 768) * 100
    out = membrane(x)

    # Per-position statistics over the feature dim. Population variance
    # (unbiased=False) matches how LayerNorm normalizes.
    mean = out.mean(dim=-1)
    var = out.var(dim=-1, unbiased=False)

    max_mean_dev = mean.abs().max()
    max_var_dev = (var - 1.0).abs().max()
    assert max_mean_dev < 0.5, f"Post-LN mean too high: {max_mean_dev:.4f}"
    # Report the quantity the condition actually tests (worst-case deviation),
    # not the average variance, which could look fine while one position fails.
    assert max_var_dev < 0.5, f"Post-LN var far from 1: max |var - 1| = {max_var_dev:.4f}"
    print("[PASS] test_membrane_layernorm_output: output is approximately normalized")
|
|
|
|
def test_plug_forward_simulation():
    """
    Full Plug-style forward: membrane -> frozen layers -> output.

    Intended to mirror what GladiusPlug.forward() does, without needing a
    checkpoint: the membrane is trainable, the mock layer stack is frozen.
    """
    class MockGladiusStack(nn.Module):
        """Stand-in for the GLADIUS layer stack: pre-LN residual blocks + final norm."""
        def __init__(self, dim=640, num_layers=14):
            super().__init__()
            self.layers = nn.ModuleList([
                nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))
                for _ in range(num_layers)
            ])
            self.final_norm = nn.LayerNorm(dim)

        def forward(self, x):
            for layer in self.layers:
                x = x + layer(x)
            return self.final_norm(x)

    membrane = Membrane(external_dim=768, gladius_dim=640)
    kernel = MockGladiusStack(dim=640, num_layers=14)

    # Freeze the whole stack.
    for p in kernel.parameters():
        p.requires_grad = False

    external_hidden = torch.randn(2, 128, 768)
    projected = membrane(external_hidden)
    enriched = kernel(projected)

    assert enriched.shape == (2, 128, 640), f"Expected (2, 128, 640), got {enriched.shape}"

    # The stack ends in a freshly initialized LayerNorm (uniform weight), so
    # `enriched.sum()` is invariant to its input and would send only rounding
    # noise back to the membrane. A random weighting gives a real training signal.
    loss = (enriched * torch.randn_like(enriched)).sum()
    loss.backward()

    # Every membrane param must get a real gradient through the frozen stack...
    for name, param in membrane.named_parameters():
        assert param.grad is not None and param.grad.abs().sum() > 0, \
            f"Membrane param '{name}' should have a non-zero gradient"

    # ...and no kernel param may get one.
    kernel_grads = sum(1 for p in kernel.parameters() if p.grad is not None)
    assert kernel_grads == 0, "Kernel should have zero gradients (frozen)"

    print("[PASS] test_plug_forward_simulation: full pipeline, correct gradient flow")
|
|
|
|
def test_membrane_batch_sizes():
    """Output shape tracks the batch dimension, including the batch=1 edge case."""
    membrane = Membrane(external_dim=768, gladius_dim=640)
    seq_len = 64

    for bsz in (1, 2, 4, 16):
        got = membrane(torch.randn(bsz, seq_len, 768)).shape
        assert got == (bsz, seq_len, 640), f"Batch {bsz}: wrong shape {got}"

    print("[PASS] test_membrane_batch_sizes: handles batch 1, 2, 4, 16")
|
|
|
|
| |
|
|
def run_tests(tests):
    """Run zero-argument test callables and print a PASS/FAIL summary.

    Any exception raised by a test counts as a failure; the runner keeps going
    so every test is reported.

    Args:
        tests: iterable of zero-argument callables.

    Returns:
        A ``(passed, failed)`` tuple of counts.
    """
    passed = failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:  # top-level runner boundary: report and continue
            # Include the exception type: some asserts carry no message, and a
            # bare "{e}" would print an empty reason.
            print(f"[FAIL] {test.__name__}: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1

    bar = '=' * 50
    print(f"\n{bar}")
    print(f" {passed}/{passed + failed} tests passed")
    if failed:
        print(f" {failed} FAILED")
    else:
        print(" ALL PASS ✓")
    print(bar)
    return passed, failed


if __name__ == '__main__':
    tests = [
        test_membrane_shape,
        test_membrane_different_dims,
        test_membrane_gradients,
        test_freeze_kernel_simulation,
        test_membrane_save_load,
        test_membrane_param_count,
        test_membrane_layernorm_output,
        test_plug_forward_simulation,
        test_membrane_batch_sizes,
    ]

    _, failed = run_tests(tests)
    # Non-zero exit status so shells/CI detect failures when run as a script.
    sys.exit(1 if failed else 0)
|
|