# Source snapshot: commit c475135 by icarus112 —
# "Update Feather a10g-large training runtime image".
"""Tests for GPUEngram Sparse Modern Hopfield retrieval path.
Tests are written first (TDD) against the new matmul-based retrieval.
Run with: pytest tests/test_engram.py -v
"""
from __future__ import annotations
import math
import pytest
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_engram(d_model: int = 64, n_columns: int = 1024, hebbian_boost: bool = False):
    """Construct a ``GPUEngram`` for a test and switch it to eval mode.

    The ``hydra.engram`` import happens inside this helper, so importing this
    test module does not by itself import ``hydra``.
    """
    from hydra.engram import GPUEngram

    engram = GPUEngram(
        d_model=d_model,
        n_columns=n_columns,
        hebbian_boost=hebbian_boost,
    )
    engram.eval()
    return engram
# ---------------------------------------------------------------------------
# test_forward_shape
# ---------------------------------------------------------------------------
def test_forward_shape():
    """Forward keeps the (batch, seq, d_model) input shape and returns a 0-d hit rate."""
    batch, seq, width = 2, 16, 64
    engram = _make_engram(d_model=width, n_columns=1024)
    hidden = torch.randn(batch, seq, width)
    ids = torch.randint(0, 1000, (batch, seq))

    result, rate = engram(hidden, ids)

    expected_shape = (batch, seq, width)
    assert result.shape == expected_shape, f"Expected ({batch},{seq},{width}), got {result.shape}"
    assert rate.ndim == 0, f"hit_rate should be scalar, got shape {rate.shape}"
# ---------------------------------------------------------------------------
# test_gradient_flow
# ---------------------------------------------------------------------------
def test_gradient_flow():
    """Backprop through the Hopfield matmul path must reach ``self.memory.grad``.

    The old scatter-gather path (``self.memory[indices]``) produced gradients
    only for the indexed rows. The new path (``scores = x @ memory.T`` then
    ``weights @ memory``) is a dense matmul, so ``memory`` sits on the autograd
    path for every query.

    Note: with a sparse normaliser (entmax / top-k), columns whose attention
    weight is exactly zero may still receive zero gradient. This test therefore
    only requires *some* non-zero gradient entries, not one per column, and
    additionally checks that the gradient is finite.
    """
    D, N = 64, 128
    m = _make_engram(d_model=D, n_columns=N)
    m.train()
    x = torch.randn(2, 8, D, requires_grad=True)
    token_ids = torch.randint(0, 100, (2, 8))
    out, _ = m(x, token_ids)
    loss = out.sum()
    loss.backward()
    assert m.memory.grad is not None, "self.memory.grad must be non-None after backward"
    assert m.memory.grad.abs().sum() > 0, "self.memory.grad must have non-zero entries"
    assert torch.isfinite(m.memory.grad).all(), "self.memory.grad contains NaN or Inf"
# ---------------------------------------------------------------------------
# test_sparsity
# ---------------------------------------------------------------------------
def test_sparsity():
    """At least 95% of alpha-entmax attention weights must be exactly zero.

    entmax-1.5 (alpha-entmax) produces truly sparse distributions. Sparsity
    increases with score spread — after gradient descent the memory keys will
    be unit-scale. We use unit-norm memory to represent the operating condition
    (not the tiny 0.01-init default, which would produce near-uniform scores
    and thus lower sparsity by design).

    Sparsity on random data is a statistical property, so the memory keys and
    queries are drawn from a private, seeded ``torch.Generator``. This makes
    the result deterministic without perturbing the global RNG that other
    tests use.

    NOTE(review): this replicates the retrieval math (``x @ memory.T`` followed
    by entmax-1.5, or a top-k softmax when ``entmax`` is not installed) instead
    of calling ``m.forward``. Keep it in sync with the module's retrieve path.
    """
    D, N = 64, 1024
    from hydra.engram import GPUEngram
    gen = torch.Generator().manual_seed(0)
    m = GPUEngram(d_model=D, n_columns=N)
    # Re-initialise memory to unit-norm scale — representative of trained weights.
    with torch.no_grad():
        m.memory.data = torch.nn.functional.normalize(
            torch.randn(N, D, generator=gen), dim=-1
        )
    m.eval()
    x = torch.randn(4, 32, D, generator=gen)

    # Keep the try body to the import only, so an ImportError raised while
    # *running* entmax is not mistaken for "entmax is unavailable".
    try:
        from entmax import entmax15
    except ImportError:
        entmax15 = None

    # Replicate the retrieve path to inspect weights directly.
    with torch.no_grad():
        scores = x @ m.memory.T  # (4, 32, N)
        if entmax15 is not None:
            weights = entmax15(scores, dim=-1)
        else:
            # top-k softmax fallback: keeping k=32 of N=1024 entries
            # guarantees at least 1 - 32/1024 ≈ 96.9% zeros.
            k = 32
            topk_vals, topk_idx = scores.topk(k, dim=-1)
            topk_w = torch.softmax(topk_vals, dim=-1)
            weights = torch.zeros_like(scores)
            weights.scatter_(-1, topk_idx, topk_w)
    zero_fraction = (weights == 0).float().mean().item()
    assert zero_fraction >= 0.95, (
        f"Expected >= 95% sparsity in attention weights, got {zero_fraction:.3f}"
    )
# ---------------------------------------------------------------------------
# test_no_nan_on_zero_input
# ---------------------------------------------------------------------------
def test_no_nan_on_zero_input():
    """A zero query with all-zero token ids must not produce NaN/Inf in any output."""
    width, columns = 64, 256
    # _make_engram already leaves the module in eval mode.
    engram = _make_engram(d_model=width, n_columns=columns)
    zero_queries = torch.zeros(1, 8, width)
    zero_ids = torch.zeros(1, 8, dtype=torch.long)

    out, hit_rate = engram(zero_queries, zero_ids)

    assert torch.isfinite(out).all(), "Output contains NaN or Inf on zero input"
    assert torch.isfinite(hit_rate), "hit_rate is NaN or Inf on zero input"
# ---------------------------------------------------------------------------
# test_scales_to_32k
# ---------------------------------------------------------------------------
def test_scales_to_32k():
    """A 32768-column engram must complete a CPU forward pass (no OOM) with the input shape."""
    from hydra.engram import GPUEngram

    width, columns = 128, 32768
    engram = GPUEngram(d_model=width, n_columns=columns)
    engram.eval()
    queries = torch.randn(1, 64, width)
    ids = torch.randint(0, 1000, (1, 64))

    out, _hit_rate = engram(queries, ids)

    assert out.shape == (1, 64, width), f"Expected (1, 64, {width}), got {out.shape}"
    assert torch.isfinite(out).all(), "Output contains NaN/Inf at n_columns=32768"
# ---------------------------------------------------------------------------
# Bonus: hebbian_boost=False (default) does NOT update memory.data during train
# ---------------------------------------------------------------------------
def test_hebbian_off_by_default():
    """With hebbian_boost=False, a training-mode forward leaves memory.data unchanged."""
    width, columns = 32, 64
    engram = _make_engram(d_model=width, n_columns=columns, hebbian_boost=False)
    engram.train()
    snapshot = engram.memory.data.clone()

    # Arguments are evaluated left to right: queries first, then token ids.
    engram(torch.randn(2, 4, width), torch.randint(0, 50, (2, 4)))

    assert torch.equal(snapshot, engram.memory.data), (
        "memory.data was mutated during forward but hebbian_boost=False"
    )
def test_hebbian_on_updates_memory():
    """With hebbian_boost=True, a training-mode forward must modify memory.data."""
    from hydra.engram import GPUEngram

    width, columns = 32, 64
    engram = GPUEngram(d_model=width, n_columns=columns, hebbian_boost=True)
    engram.train()
    snapshot = engram.memory.data.clone()

    # Arguments are evaluated left to right: queries first, then token ids.
    engram(torch.randn(2, 4, width), torch.randint(0, 50, (2, 4)))

    unchanged = torch.equal(snapshot, engram.memory.data)
    assert not unchanged, (
        "memory.data was NOT mutated during forward but hebbian_boost=True"
    )