"""Tests for GPUEngram Sparse Modern Hopfield retrieval path.

Tests are written first (TDD) against the new matmul-based retrieval.
Run with: pytest tests/test_engram.py -v
"""
from __future__ import annotations
import math
import pytest
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_engram(d_model: int = 64, n_columns: int = 1024, hebbian_boost: bool = False):
    """Construct a GPUEngram for tests and switch it to eval mode.

    Args:
        d_model: Model/embedding width forwarded to GPUEngram.
        n_columns: Number of memory columns forwarded to GPUEngram.
        hebbian_boost: Forwarded verbatim; always passed explicitly.

    Returns:
        The freshly built module, after ``.eval()`` has been called on it.
    """
    # Importing here (not at file top) means importing this test module does
    # not itself import hydra; the dependency is only touched when a test runs.
    from hydra.engram import GPUEngram

    engram = GPUEngram(
        d_model=d_model,
        n_columns=n_columns,
        hebbian_boost=hebbian_boost,
    )
    engram.eval()
    return engram
# ---------------------------------------------------------------------------
# test_forward_shape
# ---------------------------------------------------------------------------
def test_forward_shape():
    """The forward pass keeps the input's (B, T, D) shape and yields a 0-d hit_rate."""
    batch, seq, dim = 2, 16, 64
    engram = _make_engram(d_model=dim, n_columns=1024)
    inputs = torch.randn(batch, seq, dim)
    ids = torch.randint(0, 1000, (batch, seq))

    out, hit_rate = engram(inputs, ids)

    expected_shape = (batch, seq, dim)
    assert out.shape == expected_shape, f"Expected ({batch},{seq},{dim}), got {out.shape}"
    assert hit_rate.ndim == 0, f"hit_rate should be scalar, got shape {hit_rate.shape}"
# ---------------------------------------------------------------------------
# test_gradient_flow
# ---------------------------------------------------------------------------
def test_gradient_flow():
    """Backprop through the Hopfield matmul path must reach ``self.memory.grad``.

    The old scatter-gather path used ``self.memory[indices]``, which produced
    gradients only for the indexed rows. The new path
    (``scores = x @ memory.T`` then ``weights @ memory``) is a dense matmul
    over the whole memory, so gradients are no longer limited to looked-up
    rows. With sparse (entmax / top-k) weights, columns outside every query's
    support can still receive zero gradient. This test therefore requires
    *some* non-zero, finite gradient, not a non-zero gradient on every column.

    ``x`` is created with ``requires_grad=True``, so the gradient must also
    flow back to the input through the score matmul.
    """
    D, N = 64, 128
    m = _make_engram(d_model=D, n_columns=N)
    m.train()
    x = torch.randn(2, 8, D, requires_grad=True)
    token_ids = torch.randint(0, 100, (2, 8))

    out, _ = m(x, token_ids)
    out.sum().backward()

    grad = m.memory.grad
    assert grad is not None, "self.memory.grad must be non-None after backward"
    assert torch.isfinite(grad).all(), "self.memory.grad contains NaN or Inf"
    assert grad.abs().sum() > 0, "self.memory.grad must have non-zero entries"
    # Without this check, requires_grad=True on x would be dead code.
    assert x.grad is not None, "gradient must flow back to the input x"
# ---------------------------------------------------------------------------
# test_sparsity
# ---------------------------------------------------------------------------
def test_sparsity():
    """At least 95% of alpha-entmax attention weights must be exactly zero.

    entmax-1.5 (alpha-entmax) produces truly sparse distributions. Sparsity
    increases with score spread — after gradient descent the memory keys will
    be unit-scale. We use unit-norm memory to represent the operating condition
    (not the tiny 0.01-init default, which would produce near-uniform scores
    and thus lower sparsity by design).

    NOTE: this test replicates the retrieval maths (``x @ memory.T`` followed
    by entmax15) on the module's memory rather than calling the module's own
    retrieve path. When the optional ``entmax`` package is not installed, the
    top-k fallback below is sparse by construction (32 of 1024 non-zero), so
    the assertion is only informative with entmax available.
    """
    D, N = 64, 1024
    m = _make_engram(d_model=D, n_columns=N)
    # Private generator: deterministic inputs without reseeding the global RNG,
    # so the 95% threshold cannot flake between runs.
    gen = torch.Generator().manual_seed(0)
    with torch.no_grad():
        # Re-initialise memory to unit-norm scale — representative of trained
        # weights. copy_ writes into the existing Parameter (keeping its dtype
        # and device) instead of swapping the tensor behind it via `.data =`.
        m.memory.copy_(
            torch.nn.functional.normalize(torch.randn(N, D, generator=gen), dim=-1)
        )
        x = torch.randn(4, 32, D, generator=gen)
        scores = x @ m.memory.T  # (4, 32, N)
        try:
            from entmax import entmax15
        except ImportError:
            # top-k softmax fallback: k=32 keeps 32 of N=1024 entries,
            # i.e. >= 96.9% zeros by construction.
            k = 32
            topk_vals, topk_idx = scores.topk(k, dim=-1)
            weights = torch.zeros_like(scores).scatter_(
                -1, topk_idx, torch.softmax(topk_vals, dim=-1)
            )
        else:
            # Kept out of the try so a genuine error inside entmax15 is not
            # mistaken for a missing package.
            weights = entmax15(scores, dim=-1)
    zero_fraction = (weights == 0).float().mean().item()
    assert zero_fraction >= 0.95, (
        f"Expected >= 95% sparsity in attention weights, got {zero_fraction:.3f}"
    )
# ---------------------------------------------------------------------------
# test_no_nan_on_zero_input
# ---------------------------------------------------------------------------
def test_no_nan_on_zero_input():
    """All-zero input must produce a finite, correctly shaped output (no NaN/Inf from entmax).

    A zero query gives identical scores for every column, which is the
    degenerate case for the normalisation step.
    """
    D, N = 64, 256
    # _make_engram already puts the module in eval mode; no second .eval().
    m = _make_engram(d_model=D, n_columns=N)
    x = torch.zeros(1, 8, D)
    token_ids = torch.zeros(1, 8, dtype=torch.long)
    out, hit_rate = m(x, token_ids)
    assert out.shape == (1, 8, D), f"Expected (1, 8, {D}), got {out.shape}"
    assert torch.isfinite(out).all(), "Output contains NaN or Inf on zero input"
    assert torch.isfinite(hit_rate), "hit_rate is NaN or Inf on zero input"
# ---------------------------------------------------------------------------
# test_scales_to_32k
# ---------------------------------------------------------------------------
def test_scales_to_32k():
    """n_columns=32768 must run on CPU without OOM and return a correct, finite result."""
    D, N = 128, 32768
    m = _make_engram(d_model=D, n_columns=N)
    x = torch.randn(1, 64, D)
    token_ids = torch.randint(0, 1000, (1, 64))
    out, hit_rate = m(x, token_ids)
    assert out.shape == (1, 64, D), f"Expected (1, 64, {D}), got {out.shape}"
    assert torch.isfinite(out).all(), "Output contains NaN/Inf at n_columns=32768"
    # hit_rate was previously computed but never checked at this scale.
    assert hit_rate.ndim == 0, f"hit_rate should be scalar, got shape {hit_rate.shape}"
    assert torch.isfinite(hit_rate), "hit_rate is NaN/Inf at n_columns=32768"
# ---------------------------------------------------------------------------
# Bonus: hebbian_boost=False (default) does NOT update memory.data during train
# ---------------------------------------------------------------------------
def test_hebbian_off_by_default():
    """With the constructor's default hebbian_boost, memory.data is unchanged after a train forward.

    GPUEngram is built WITHOUT passing ``hebbian_boost``, so this really
    exercises the class default. Going through ``_make_engram`` would always
    pass ``hebbian_boost=False`` explicitly, and the default would never be
    observed.
    """
    from hydra.engram import GPUEngram

    D, N = 32, 64
    m = GPUEngram(d_model=D, n_columns=N)
    m.train()
    before = m.memory.data.clone()
    x = torch.randn(2, 4, D)
    token_ids = torch.randint(0, 50, (2, 4))
    m(x, token_ids)
    assert torch.equal(before, m.memory.data), (
        "memory.data was mutated during forward but hebbian_boost=False"
    )
def test_hebbian_on_updates_memory():
    """Enabling hebbian_boost must change memory.data during a train-mode forward."""
    dim, cols = 32, 64
    engram = _make_engram(d_model=dim, n_columns=cols, hebbian_boost=True)
    # The helper returns an eval-mode module; the Hebbian path is exercised in train mode.
    engram.train()
    snapshot = engram.memory.data.clone()

    engram(torch.randn(2, 4, dim), torch.randint(0, 50, (2, 4)))

    changed = not torch.equal(snapshot, engram.memory.data)
    assert changed, (
        "memory.data was NOT mutated during forward but hebbian_boost=True"
    )
|