"""Tests for GPUEngram Sparse Modern Hopfield retrieval path. Tests are written first (TDD) against the new matmul-based retrieval. Run with: pytest tests/test_engram.py -v """ from __future__ import annotations import math import pytest import torch import torch.nn as nn # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_engram(d_model: int = 64, n_columns: int = 1024, hebbian_boost: bool = False): from hydra.engram import GPUEngram m = GPUEngram(d_model=d_model, n_columns=n_columns, hebbian_boost=hebbian_boost) m.eval() return m # --------------------------------------------------------------------------- # test_forward_shape # --------------------------------------------------------------------------- def test_forward_shape(): """Output tensor matches input shape; hit_rate is a scalar.""" B, T, D = 2, 16, 64 m = _make_engram(d_model=D, n_columns=1024) x = torch.randn(B, T, D) token_ids = torch.randint(0, 1000, (B, T)) out, hit_rate = m(x, token_ids) assert out.shape == (B, T, D), f"Expected ({B},{T},{D}), got {out.shape}" assert hit_rate.ndim == 0, f"hit_rate should be scalar, got shape {hit_rate.shape}" # --------------------------------------------------------------------------- # test_gradient_flow # --------------------------------------------------------------------------- def test_gradient_flow(): """Backprop through the Hopfield matmul path must reach self.memory.grad. The old scatter-gather path used self.memory[indices] which DID produce gradients only for indexed rows. The new path (scores = x @ memory.T then weights @ memory) creates a full matmul, so every column gets a non-zero gradient signal (on a random batch where all keys are attended to). """ D, N = 64, 128 m = _make_engram(d_model=D, n_columns=N) m.train() x = torch.randn(2, 8, D, requires_grad=True) token_ids = torch.randint(0, 100, (2, 8)) out, _ = m(x, token_ids) loss = out.sum() loss.backward() assert m.memory.grad is not None, "self.memory.grad must be non-None after backward" assert m.memory.grad.abs().sum() > 0, "self.memory.grad must have non-zero entries" # --------------------------------------------------------------------------- # test_sparsity # --------------------------------------------------------------------------- def test_sparsity(): """At least 95% of alpha-entmax attention weights must be exactly zero. entmax-1.5 (alpha-entmax) produces truly sparse distributions. Sparsity increases with score spread — after gradient descent the memory keys will be unit-scale. We use unit-norm memory to represent the operating condition (not the tiny 0.01-init default, which would produce near-uniform scores and thus lower sparsity by design). """ D, N = 64, 1024 from hydra.engram import GPUEngram m = GPUEngram(d_model=D, n_columns=N) # Re-initialise memory to unit-norm scale — representative of trained weights. with torch.no_grad(): m.memory.data = torch.nn.functional.normalize( torch.randn(N, D), dim=-1 ) m.eval() x = torch.randn(4, 32, D) token_ids = torch.randint(0, 500, (4, 32)) # Replicate the retrieve path to inspect weights directly. with torch.no_grad(): scores = x @ m.memory.T # (4, 32, N) try: from entmax import entmax15 weights = entmax15(scores, dim=-1) except ImportError: # top-k softmax fallback: k=32, guaranteed ≥ 96.9% zeros at N=1024 k = 32 topk_vals, topk_idx = scores.topk(k, dim=-1) topk_w = torch.softmax(topk_vals, dim=-1) weights = torch.zeros_like(scores) weights.scatter_(-1, topk_idx, topk_w) zero_fraction = (weights == 0).float().mean().item() assert zero_fraction >= 0.95, ( f"Expected >= 95% sparsity in attention weights, got {zero_fraction:.3f}" ) # --------------------------------------------------------------------------- # test_no_nan_on_zero_input # --------------------------------------------------------------------------- def test_no_nan_on_zero_input(): """All-zero input must produce a finite output (no NaN/Inf from entmax).""" D, N = 64, 256 m = _make_engram(d_model=D, n_columns=N) m.eval() x = torch.zeros(1, 8, D) token_ids = torch.zeros(1, 8, dtype=torch.long) out, hit_rate = m(x, token_ids) assert torch.isfinite(out).all(), "Output contains NaN or Inf on zero input" assert torch.isfinite(hit_rate), "hit_rate is NaN or Inf on zero input" # --------------------------------------------------------------------------- # test_scales_to_32k # --------------------------------------------------------------------------- def test_scales_to_32k(): """n_columns=32768 must run on CPU without OOM and return correct shape.""" D, N = 128, 32768 from hydra.engram import GPUEngram m = GPUEngram(d_model=D, n_columns=N) m.eval() x = torch.randn(1, 64, D) token_ids = torch.randint(0, 1000, (1, 64)) out, hit_rate = m(x, token_ids) assert out.shape == (1, 64, D), f"Expected (1, 64, {D}), got {out.shape}" assert torch.isfinite(out).all(), "Output contains NaN/Inf at n_columns=32768" # --------------------------------------------------------------------------- # Bonus: hebbian_boost=False (default) does NOT update memory.data during train # --------------------------------------------------------------------------- def test_hebbian_off_by_default(): """With default hebbian_boost=False, memory.data is unchanged after train forward.""" D, N = 32, 64 m = _make_engram(d_model=D, n_columns=N, hebbian_boost=False) m.train() before = m.memory.data.clone() x = torch.randn(2, 4, D) token_ids = torch.randint(0, 50, (2, 4)) m(x, token_ids) after = m.memory.data assert torch.equal(before, after), ( "memory.data was mutated during forward but hebbian_boost=False" ) def test_hebbian_on_updates_memory(): """With hebbian_boost=True, memory.data changes after train forward.""" D, N = 32, 64 from hydra.engram import GPUEngram m = GPUEngram(d_model=D, n_columns=N, hebbian_boost=True) m.train() before = m.memory.data.clone() x = torch.randn(2, 4, D) token_ids = torch.randint(0, 50, (2, 4)) m(x, token_ids) after = m.memory.data assert not torch.equal(before, after), ( "memory.data was NOT mutated during forward but hebbian_boost=True" )