Spaces:
Runtime error
Runtime error
| """Tests for GPUEngram Sparse Modern Hopfield retrieval path. | |
| Tests are written first (TDD) against the new matmul-based retrieval. | |
| Run with: pytest tests/test_engram.py -v | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import pytest | |
| import torch | |
| import torch.nn as nn | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _make_engram(d_model: int = 64, n_columns: int = 1024, hebbian_boost: bool = False):
    """Build a ``GPUEngram`` and put it in evaluation mode.

    ``hydra.engram`` is imported inside the function, i.e. at call time
    rather than when this test module is imported.

    Args:
        d_model: Forwarded as ``GPUEngram(d_model=...)``.
        n_columns: Forwarded as ``GPUEngram(n_columns=...)``.
        hebbian_boost: Forwarded as ``GPUEngram(hebbian_boost=...)``.

    Returns:
        The constructed module, after ``.eval()`` has been called on it.
    """
    from hydra.engram import GPUEngram

    engram = GPUEngram(
        d_model=d_model,
        n_columns=n_columns,
        hebbian_boost=hebbian_boost,
    )
    engram.eval()
    return engram
| # --------------------------------------------------------------------------- | |
| # test_forward_shape | |
| # --------------------------------------------------------------------------- | |
def test_forward_shape():
    """Forward pass keeps the input's shape and returns a 0-dim hit rate."""
    B, T, D = 2, 16, 64
    engram = _make_engram(d_model=D, n_columns=1024)

    # Random activations plus random token ids drawn from [0, 1000).
    x = torch.randn(B, T, D)
    token_ids = torch.randint(0, 1000, (B, T))

    out, hit_rate = engram(x, token_ids)

    expected_shape = (B, T, D)
    assert out.shape == expected_shape, f"Expected ({B},{T},{D}), got {out.shape}"
    assert hit_rate.ndim == 0, f"hit_rate should be scalar, got shape {hit_rate.shape}"
| # --------------------------------------------------------------------------- | |
| # test_gradient_flow | |
| # --------------------------------------------------------------------------- | |
def test_gradient_flow():
    """A backward pass in train mode must populate ``memory.grad``.

    One forward pass is run on a random batch, the output is reduced with
    ``sum()`` and backpropagated. The test passes when ``memory.grad`` exists
    and the sum of its absolute values is greater than zero. It checks the
    aggregate only; it does not verify that each individual column receives
    a gradient.
    """
    d_model, n_columns = 64, 128
    engram = _make_engram(d_model=d_model, n_columns=n_columns)
    engram.train()  # the helper hands the module back in eval mode

    x = torch.randn(2, 8, d_model, requires_grad=True)
    token_ids = torch.randint(0, 100, (2, 8))

    out, _ = engram(x, token_ids)
    out.sum().backward()

    memory_grad = engram.memory.grad
    assert memory_grad is not None, "self.memory.grad must be non-None after backward"
    assert memory_grad.abs().sum() > 0, "self.memory.grad must have non-zero entries"
| # --------------------------------------------------------------------------- | |
| # test_sparsity | |
| # --------------------------------------------------------------------------- | |
def test_sparsity():
    """At least 95% of the retrieval attention weights must be exactly zero.

    The memory is overwritten with unit-norm random rows before scoring, to
    stand in for trained weights; the original author notes that the module's
    small default init would give near-uniform scores and lower sparsity.

    The weights are computed here, in the test, from ``x @ m.memory.T``:
    with ``entmax15`` when the ``entmax`` package is importable, otherwise
    with a top-k (k=32) softmax scattered into a zero tensor.
    NOTE(review): this replicates the retrieval path instead of calling the
    module's forward -- confirm it still matches ``hydra.engram``.
    """
    from hydra.engram import GPUEngram

    # Guard only the import: an ImportError raised while *calling* entmax15
    # should fail the test, not silently switch to the fallback path.
    try:
        from entmax import entmax15
    except ImportError:
        entmax15 = None

    D, N = 64, 1024
    m = GPUEngram(d_model=D, n_columns=N)

    # Re-initialise memory to unit-norm scale -- representative of trained weights.
    with torch.no_grad():
        m.memory.data = torch.nn.functional.normalize(torch.randn(N, D), dim=-1)
    m.eval()

    x = torch.randn(4, 32, D)

    with torch.no_grad():
        scores = x @ m.memory.T  # (4, 32, N)
        if entmax15 is not None:
            weights = entmax15(scores, dim=-1)
        else:
            # Top-k softmax fallback: only k=32 of N=1024 entries per row are
            # written, so at least 1 - 32/1024 = 96.875% of weights are zero.
            k = 32
            topk_vals, topk_idx = scores.topk(k, dim=-1)
            topk_w = torch.softmax(topk_vals, dim=-1)
            weights = torch.zeros_like(scores)
            weights.scatter_(-1, topk_idx, topk_w)

    zero_fraction = (weights == 0).float().mean().item()
    assert zero_fraction >= 0.95, (
        f"Expected >= 95% sparsity in attention weights, got {zero_fraction:.3f}"
    )
| # --------------------------------------------------------------------------- | |
| # test_no_nan_on_zero_input | |
| # --------------------------------------------------------------------------- | |
def test_no_nan_on_zero_input():
    """An all-zero batch must yield finite ``out`` and ``hit_rate`` values."""
    d_model, n_columns = 64, 256
    engram = _make_engram(d_model=d_model, n_columns=n_columns)
    engram.eval()

    zero_input = torch.zeros(1, 8, d_model)
    zero_tokens = torch.zeros(1, 8, dtype=torch.long)

    out, hit_rate = engram(zero_input, zero_tokens)

    assert torch.isfinite(out).all(), "Output contains NaN or Inf on zero input"
    assert torch.isfinite(hit_rate), "hit_rate is NaN or Inf on zero input"
| # --------------------------------------------------------------------------- | |
| # test_scales_to_32k | |
| # --------------------------------------------------------------------------- | |
def test_scales_to_32k():
    """A 32768-column engram must run one forward pass and stay finite.

    Builds the module with ``d_model=128`` and ``n_columns=32768``, feeds a
    single sequence of 64 random vectors, and checks the output shape and
    that every output value is finite.
    """
    from hydra.engram import GPUEngram

    D = 128
    n_columns = 32768
    engram = GPUEngram(d_model=D, n_columns=n_columns)
    engram.eval()

    x = torch.randn(1, 64, D)
    token_ids = torch.randint(0, 1000, (1, 64))
    out, hit_rate = engram(x, token_ids)

    expected_shape = (1, 64, D)
    assert out.shape == expected_shape, f"Expected (1, 64, {D}), got {out.shape}"
    assert torch.isfinite(out).all(), "Output contains NaN/Inf at n_columns=32768"
| # --------------------------------------------------------------------------- | |
| # Bonus: hebbian_boost=False (default) does NOT update memory.data during train | |
| # --------------------------------------------------------------------------- | |
def test_hebbian_off_by_default():
    """Without Hebbian boosting, a train-mode forward leaves ``memory.data`` alone.

    Two constructions are checked:

    * no ``hebbian_boost`` argument at all, so the constructor's own default
      is what gets exercised (the behaviour the test name promises), and
    * an explicit ``hebbian_boost=False``, which is what this test covered
      before.

    Passing the flag explicitly in every case would let a changed constructor
    default go unnoticed.
    """
    from hydra.engram import GPUEngram

    D, N = 32, 64
    for extra_kwargs in ({}, {"hebbian_boost": False}):
        m = GPUEngram(d_model=D, n_columns=N, **extra_kwargs)
        m.train()

        # Clone first: ``memory.data`` is compared after the forward pass, and
        # an in-place update would otherwise change both sides of the check.
        before = m.memory.data.clone()
        x = torch.randn(2, 4, D)
        token_ids = torch.randint(0, 50, (2, 4))
        m(x, token_ids)
        after = m.memory.data

        assert torch.equal(before, after), (
            "memory.data was mutated during forward but hebbian_boost=False"
        )
def test_hebbian_on_updates_memory():
    """A train-mode forward with ``hebbian_boost=True`` must alter ``memory.data``."""
    from hydra.engram import GPUEngram

    d_model, n_columns = 32, 64
    engram = GPUEngram(d_model=d_model, n_columns=n_columns, hebbian_boost=True)
    engram.train()

    # Snapshot a copy so the comparison below is against the pre-forward values.
    snapshot = engram.memory.data.clone()

    x = torch.randn(2, 4, d_model)
    token_ids = torch.randint(0, 50, (2, 4))
    engram(x, token_ids)

    unchanged = torch.equal(snapshot, engram.memory.data)
    assert not unchanged, (
        "memory.data was NOT mutated during forward but hebbian_boost=True"
    )