| """Tests for Post-SEM-Claw model subsystems. |
| |
| Verifies forward pass shapes, dtype correctness, and interface contracts. |
| All tests use small configs to run quickly on CPU. |
| |
| Run: |
| uv run pytest tests/test_subsystems.py -v |
| """ |
| import sys |
| import os |
| import types |
| import importlib |
| import pytest |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Absolute path of the directory containing this test file.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
# Repository root: the parent of the tests directory (train.py is looked up here).
_REPO = os.path.dirname(_TESTS_DIR)
|
|
|
|
def _load_train_classes():
    """Load model classes from train.py without running the training loop.

    The text of ``train.py`` (looked up under ``_REPO``) is read, truncated at
    the first cutoff marker that occurs in it (markers are tried in the order
    listed), and executed in a fresh namespace whose ``__name__`` is
    ``"train"``.

    Before execution a stub ``prepare`` module is registered in
    ``sys.modules`` -- only if no module of that name is registered already --
    so that any ``import prepare`` performed by the executed source resolves
    to the stub.

    Returns:
        dict: The namespace populated by executing the truncated source.

    Note:
        If none of the markers occurs in the file, the complete source is
        executed unchanged.
    """
    train_path = os.path.join(_REPO, "train.py")
    # Decode explicitly as UTF-8: the platform default encoding (for example
    # cp1252 on Windows) can fail on, or mis-decode, non-ASCII characters.
    with open(train_path, encoding="utf-8") as fh:
        source = fh.read()

    # Everything from the first marker found onward is discarded.
    # NOTE(review): presumably these markers are where train.py's setup and
    # training code begins -- verify against train.py when it changes.
    cutoff_markers = [
        "\n# ---------------------------------------------------------------------------\n# Setup:",
        "\nt_start = time.time()",
    ]
    for marker in cutoff_markers:
        idx = source.find(marker)
        if idx != -1:
            source = source[:idx]
            break

    # Minimal stand-in for the `prepare` module.
    fake_prepare = types.ModuleType("prepare")
    fake_prepare.MAX_SEQ_LEN = 2048
    fake_prepare.TIME_BUDGET = 300
    fake_prepare.Tokenizer = object
    fake_prepare.make_dataloader = lambda *a, **kw: None
    fake_prepare.evaluate_bpb = lambda *a, **kw: 0.0
    # setdefault keeps an already-registered `prepare` module in place.
    sys.modules.setdefault("prepare", fake_prepare)

    # Compile against the real path so tracebacks point at train.py.
    ns: dict = {"__name__": "train"}
    exec(compile(source, train_path, "exec"), ns)
    return ns
|
|
|
|
# Execute the class-defining part of train.py once, when this module is imported.
_TRAIN = _load_train_classes()

# Bind the symbols under test from the executed namespace. A name missing from
# train.py raises KeyError here, i.e. when this module is imported.
(
    PostSemClawConfig,
    PostSemClawModel,
    Mamba3Block,
    ManifoldHyperConnection,
    EngramModule,
    HestiaQAT,
    StochasticResonanceSDR,
    norm,
) = (
    _TRAIN[symbol]
    for symbol in (
        "PostSemClawConfig",
        "PostSemClawModel",
        "Mamba3Block",
        "ManifoldHyperConnection",
        "EngramModule",
        "HestiaQAT",
        "StochasticResonanceSDR",
        "norm",
    )
)
|
|
|
|
| |
| |
| |
|
|
def _small_config() -> PostSemClawConfig:
    """Return a deliberately tiny PostSemClawConfig shared by the tests in this file."""
    settings = {
        "sequence_len": 64,
        "vocab_size": 256,
        "n_layer": 2,
        "d_model": 64,
        "d_state": 16,
        "headdim": 16,
        "n_heads": 4,
        "expand": 2,
        "mhc_n_streams": 2,
        "mhc_sinkhorn_iters": 5,
        "engram_n_columns": 128,
        "engram_key_dim": 16,
        "engram_layer_idx": 0,
    }
    return PostSemClawConfig(**settings)
|
|
|
|
| |
| |
| |
|
|
class TestBCNorm:
    """Checks on the ``bc_norm`` submodule of a Mamba3Block built from the small config."""

    def test_output_shape(self):
        """BCNorm preserves input shape."""
        cfg = _small_config()
        block = Mamba3Block(cfg)
        bc = block.bc_norm
        x = torch.randn(2, 32, cfg.d_state)
        y = bc(x)
        assert y.shape == x.shape

    def test_output_dtype(self):
        """BCNorm preserves float32 dtype."""
        cfg = _small_config()
        block = Mamba3Block(cfg)
        x = torch.randn(2, 32, cfg.d_state)
        y = block.bc_norm(x)
        assert y.dtype == x.dtype

    def test_gradient_flow(self):
        """BCNorm lets gradients reach the input, the weight and, if present, the bias."""
        cfg = _small_config()
        block = Mamba3Block(cfg)
        x = torch.randn(2, 16, cfg.d_state, requires_grad=True)
        y = block.bc_norm(x)
        y.sum().backward()
        assert x.grad is not None
        assert block.bc_norm.weight.grad is not None
        # The bias is checked only when the module has a trainable one, so the
        # test still passes for a bias-free BCNorm.
        bias = getattr(block.bc_norm, "bias", None)
        if bias is not None and bias.requires_grad:
            assert bias.grad is not None
|
|
|
|
| |
| |
| |
|
|
class TestMamba3Block:
    """Shape, dtype, causality and autograd checks for a single Mamba3Block."""

    def test_forward_shape(self):
        """A (2, 32, d_model) input yields an output of the same shape."""
        config = _small_config()
        layer = Mamba3Block(config)
        inputs = torch.randn(2, 32, config.d_model)
        outputs = layer(inputs)
        assert outputs.shape == (2, 32, config.d_model)

    def test_forward_dtype(self):
        """The output dtype equals the input dtype."""
        config = _small_config()
        layer = Mamba3Block(config)
        inputs = torch.randn(2, 16, config.d_model)
        assert layer(inputs).dtype == inputs.dtype

    def test_causal(self):
        """Zeroing inputs from position 4 onward must leave outputs at positions 0-3 unchanged."""
        config = _small_config()
        layer = Mamba3Block(config)
        layer.eval()
        seq_len = 8
        split = 4
        full_input = torch.randn(1, seq_len, config.d_model)

        # Same prefix, but every position at or after `split` is zeroed.
        truncated_input = full_input.clone()
        truncated_input[:, split:, :] = 0.0

        with torch.no_grad():
            out_full = layer(full_input)
            out_truncated = layer(truncated_input)

        prefix_matches = torch.allclose(
            out_full[:, :split, :], out_truncated[:, :split, :], atol=1e-5
        )
        assert prefix_matches, (
            "Mamba3Block is not causal: output at t<4 changed when future input zeroed"
        )

    def test_gradient_backward(self):
        """Backward through the block runs and populates the input gradient."""
        config = _small_config()
        layer = Mamba3Block(config)
        inputs = torch.randn(1, 8, config.d_model, requires_grad=True)
        layer(inputs).sum().backward()
        assert inputs.grad is not None
|
|
|
|
| |
| |
| |
|
|
class TestManifoldHyperConnection:
    """Checks on the Sinkhorn mixing matrix and the stream helpers of ManifoldHyperConnection."""

    def test_sinkhorn_doubly_stochastic(self):
        """After 20 Sinkhorn iterations the row and column sums are each ~1."""
        layer = ManifoldHyperConnection(d_model=64, n_streams=4, sinkhorn_iters=20)
        with torch.no_grad():
            mixing = layer._sinkhorn(layer.log_alpha)
        size = layer.n_streams
        assert mixing.shape == (size, size)

        row_ok = torch.allclose(mixing.sum(dim=-1), torch.ones(size), atol=1e-4)
        assert row_ok, f"Row sums not ~1: {mixing.sum(dim=-1)}"

        col_ok = torch.allclose(mixing.sum(dim=-2), torch.ones(size), atol=1e-4)
        assert col_ok, f"Col sums not ~1: {mixing.sum(dim=-2)}"

    def test_sinkhorn_non_negative(self):
        """No entry of the Sinkhorn output is negative."""
        layer = ManifoldHyperConnection(d_model=32, n_streams=3, sinkhorn_iters=10)
        with torch.no_grad():
            mixing = layer._sinkhorn(layer.log_alpha)
        assert (mixing >= 0).all()

    def test_forward_shape(self):
        """Forward with an identity block function returns a tensor shaped like the input streams."""
        config = _small_config()
        layer = ManifoldHyperConnection(
            config.d_model, config.mhc_n_streams, config.mhc_sinkhorn_iters
        )
        batch, seq_len = 2, 16
        streams = torch.randn(config.mhc_n_streams, batch, seq_len, config.d_model)

        def identity(x):
            return x

        result = layer(streams, identity)
        assert result.shape == streams.shape

    def test_init_streams_shape(self):
        """init_streams maps (B, T, d_model) to (n_streams, B, T, d_model)."""
        config = _small_config()
        layer = ManifoldHyperConnection(
            config.d_model, config.mhc_n_streams, config.mhc_sinkhorn_iters
        )
        hidden = torch.randn(2, 16, config.d_model)
        expected = (config.mhc_n_streams, 2, 16, config.d_model)
        assert layer.init_streams(hidden).shape == expected

    def test_merge_streams_shape(self):
        """merge_streams maps (n_streams, B, T, d_model) to (B, T, d_model)."""
        config = _small_config()
        layer = ManifoldHyperConnection(
            config.d_model, config.mhc_n_streams, config.mhc_sinkhorn_iters
        )
        streams = torch.randn(config.mhc_n_streams, 2, 16, config.d_model)
        assert layer.merge_streams(streams).shape == (2, 16, config.d_model)
|
|
|
|
| |
| |
| |
|
|
class TestEngramModule:
    """Shape, hit-rate and autograd checks for EngramModule."""

    def test_forward_shape(self):
        """The first element of the returned pair has the input's shape."""
        module = EngramModule(d_model=64, n_columns=128, key_dim=16)
        inputs = torch.randn(2, 16, 64)
        result = module(inputs)[0]
        assert result.shape == inputs.shape

    def test_hit_rate_range(self):
        """The second element of the returned pair (hit rate) lies in [0, 1]."""
        module = EngramModule(d_model=64, n_columns=128, key_dim=16)
        inputs = torch.randn(4, 32, 64)
        _, hit_rate = module(inputs)
        assert 0.0 <= hit_rate <= 1.0, f"hit_rate={hit_rate} out of [0,1]"

    def test_gradient_flow(self):
        """Back-propagating the summed output populates the input gradient."""
        module = EngramModule(d_model=32, n_columns=64, key_dim=8)
        inputs = torch.randn(1, 8, 32, requires_grad=True)
        result, _ = module(inputs)
        result.sum().backward()
        assert inputs.grad is not None
|
|
|
|
| |
| |
| |
|
|
class TestHestiaQAT:
    """Checks for HestiaQAT in both its disabled and enabled configurations."""

    def test_disabled_quantize_is_identity(self):
        """With enabled=False, quantize_weight returns a tensor equal to its input."""
        qat = HestiaQAT(enabled=False)
        weight = torch.randn(4, 4)
        assert torch.equal(qat.quantize_weight(weight), weight)

    def test_disabled_forward_is_noop(self):
        """With enabled=False, calling the module on a Linear leaves its weight untouched."""
        qat = HestiaQAT(enabled=False)
        layer = nn.Linear(4, 4)
        snapshot = layer.weight.data.clone()
        qat(layer)
        assert torch.equal(layer.weight.data, snapshot)

    def test_disabled_quant_error_is_zero(self):
        """With enabled=False, get_quant_error reports 0.0."""
        qat = HestiaQAT(enabled=False)
        layer = nn.Linear(8, 8)
        error = qat.get_quant_error(layer)
        assert error == 0.0

    def test_enabled_quantize_ternary(self):
        """With enabled=True and bits=1.58, every quantized value is ~0 or ~+/-scale.

        ``scale`` is the mean absolute value of the unquantized weight.
        """
        qat = HestiaQAT(enabled=True, bits=1.58)
        weight = torch.randn(8, 8)
        quantized = qat.quantize_weight(weight)
        scale = weight.abs().mean().item()

        for value in quantized.detach().unique().tolist():
            magnitude = abs(value)
            near_zero = magnitude < 1e-4
            near_scale = abs(magnitude - scale) < 1e-4
            assert near_zero or near_scale, (
                f"Unexpected quantized value {value}, scale={scale}"
            )
|
|
|
|
| |
| |
| |
|
|
class TestStochasticResonanceSDR:
    """Checks for StochasticResonanceSDR with the SR path disabled and enabled."""

    def test_bypass_shape(self):
        """With enabled=False the output keeps the input's shape."""
        module = StochasticResonanceSDR(d_model=64, k=16, enabled=False)
        inputs = torch.randn(2, 32, 64)
        result, _ = module(inputs)
        assert result.shape == inputs.shape

    def test_bypass_rate_one(self):
        """With enabled=False the reported bypass rate is exactly 1.0."""
        module = StochasticResonanceSDR(d_model=64, k=16, enabled=False)
        inputs = torch.randn(2, 8, 64)
        rate = module(inputs)[1]
        assert rate == 1.0

    def test_topk_sparsity(self):
        """Each position of the output has exactly k non-zero entries along the last dim."""
        k = 8
        module = StochasticResonanceSDR(d_model=32, k=k, enabled=False)
        inputs = torch.randn(2, 4, 32)
        result, _ = module(inputs)

        # Count non-zero features per (batch, position).
        nnz = (result != 0).sum(dim=-1)
        all_match = (nnz == k).all()
        assert all_match, f"Expected {k} non-zeros, got {nnz}"

    def test_sr_enabled_shape(self):
        """With enabled=True the output also keeps the input's shape."""
        module = StochasticResonanceSDR(d_model=32, k=8, noise_std=0.01, enabled=True)
        inputs = torch.randn(1, 4, 32)
        result, _ = module(inputs)
        assert result.shape == inputs.shape
|
|
|
|
| |
| |
| |
|
|
class TestPostSemClawModel:
    """Interface checks for the full PostSemClawModel built from the small config."""

    @pytest.fixture
    def small_model(self):
        """Return a freshly constructed model using the small test config."""
        cfg = _small_config()
        return PostSemClawModel(cfg)

    @staticmethod
    def _random_tokens(batch, seq_len):
        """Return random token ids of shape (batch, seq_len).

        The exclusive upper bound is taken from the small config, so the ids
        stay valid if that config's vocab_size is changed.
        """
        return torch.randint(0, _small_config().vocab_size, (batch, seq_len))

    def test_forward_loss_mean(self, small_model):
        """Forward with targets and reduction='mean' returns a positive scalar."""
        B, T = 2, 16
        idx = self._random_tokens(B, T)
        targets = self._random_tokens(B, T)
        loss = small_model(idx, targets, reduction="mean")
        assert loss.shape == (), f"Expected scalar, got shape {loss.shape}"
        assert loss.item() > 0

    def test_forward_loss_none(self, small_model):
        """Forward with reduction='none' returns (B*T,) shaped tensor."""
        B, T = 2, 16
        idx = self._random_tokens(B, T)
        targets = self._random_tokens(B, T)
        loss = small_model(idx, targets, reduction="none")
        assert loss.shape == (B * T,), f"Expected ({B*T},), got {loss.shape}"

    def test_forward_logits(self, small_model):
        """Forward without targets returns (B, T, vocab_size) logits."""
        B, T = 2, 16
        idx = self._random_tokens(B, T)
        logits = small_model(idx)
        assert logits.shape == (B, T, _small_config().vocab_size)

    def test_backward(self, small_model):
        """Forward loss is positive and the embedding table receives a gradient.

        This test does NOT call ``backward()`` on the full-model loss. It
        computes the loss once and checks only that it is positive, then
        back-propagates through the token embedding (``wte``) alone and
        checks that ``wte.weight.grad`` is populated. No autocast context is
        used and ``lm_head`` gradients are not checked.

        NOTE(review): the previous docstring attributed this workaround to an
        in-place ``streams[0] = primary`` assignment in the model forward
        that breaks autograd in float32. That cannot be confirmed from this
        file -- TODO confirm, and restore a real ``loss.backward()`` check
        once the full backward pass works.
        """
        idx = self._random_tokens(1, 8)
        targets = self._random_tokens(1, 8)

        small_model.zero_grad()

        loss = small_model(idx, targets, reduction="mean")
        assert loss.item() > 0

        # Gradient check restricted to the embedding lookup.
        emb_out = small_model.wte(idx)
        emb_out.sum().backward()
        assert small_model.wte.weight.grad is not None

    def test_init_weights(self, small_model):
        """init_weights() runs without raising any exception."""
        small_model.init_weights()

    def test_secondary_metrics_keys(self, small_model):
        """get_secondary_metrics() returns the expected keys after a forward pass."""
        idx = self._random_tokens(1, 8)
        targets = self._random_tokens(1, 8)
        small_model(idx, targets)
        metrics = small_model.get_secondary_metrics()
        expected_keys = {"mhc_spectral_norm", "engram_hit_rate", "sr_bypass_rate", "hestia_quant_error"}
        assert expected_keys.issubset(set(metrics.keys())), (
            f"Missing keys: {expected_keys - set(metrics.keys())}"
        )

    def test_secondary_metrics_ranges(self, small_model):
        """Secondary metrics are within expected physical ranges."""
        idx = self._random_tokens(1, 8)
        small_model(idx)
        metrics = small_model.get_secondary_metrics()
        assert metrics["mhc_spectral_norm"] >= 0.0
        assert 0.0 <= metrics["engram_hit_rate"] <= 1.0
        assert metrics["sr_bypass_rate"] in (0.0, 1.0)
        assert metrics["hestia_quant_error"] >= 0.0

    def test_num_scaling_params_keys(self, small_model):
        """num_scaling_params() returns expected component keys."""
        counts = small_model.num_scaling_params()
        for key in ("wte", "lm_head", "blocks", "mhc", "engram", "total"):
            assert key in counts, f"Missing key: {key}"
        assert counts["total"] > 0

    def test_estimate_flops_positive(self, small_model):
        """estimate_flops() returns a positive value."""
        flops = small_model.estimate_flops()
        assert flops > 0
|
|