engram / tests /test_eigengram.py
eigengram's picture
test: upload 220 tests
2ece486 verified
"""
EIGENGRAM test suite — no model calls, pure format verification.
"""
from __future__ import annotations
import os
import struct
import pytest
import torch
from kvcos.engram.format import (
EigramDecoder,
EigramEncoder,
EIGENGRAM_MAGIC,
EIGENGRAM_VERSION,
)
# Saved FCDB v2 corpus basis read by the `basis` fixture; when the file is
# missing that fixture skips instead of failing. The path is relative, so it
# resolves against the current working directory of the test run.
BASIS_PATH = "results/corpus_basis_fcdb_v2.pt"
@pytest.fixture(scope="module")
def basis():
    """Load the saved corpus basis once per module, or skip when it is absent.

    Returns:
        The object stored at ``BASIS_PATH``. The tests in this file index it
        with the keys ``"basis"`` and ``"joint_center"``.
    """
    if not os.path.exists(BASIS_PATH):
        pytest.skip("FCDB v2 basis not built yet")
    # NOTE(security): weights_only=False uses the full pickle loader, which can
    # execute arbitrary code while loading. Only point BASIS_PATH at a locally
    # built, trusted file.
    # map_location="cpu" lets a basis that was saved from an accelerator load
    # on a CPU-only machine, and keeps its tensors on the same device as the
    # vectors the tests create with torch.randn (CPU unless the default device
    # has been changed).
    return torch.load(BASIS_PATH, map_location="cpu", weights_only=False)
@pytest.fixture(scope="module")
def sample_cert(basis):
    """Encode one certificate from random vectors and fixed metadata.

    The two vectors are drawn with ``torch.randn`` and sized to the first
    dimension of ``basis["basis"]``; every other field is a fixed literal that
    the round-trip tests compare against.

    Returns:
        Whatever ``EigramEncoder.encode`` returns (the tests annotate it as
        ``bytes``).
    """
    encoder = EigramEncoder()
    rank = basis["basis"].shape[0]
    # Dict entries are evaluated top to bottom, so the two randn draws happen
    # in the same order as the keyword arguments they feed.
    fields = {
        "vec_perdoc": torch.randn(rank),
        "vec_fcdb": torch.randn(rank),
        "joint_center": basis["joint_center"],
        "corpus_hash": "a" * 32,
        "model_id": "Llama-3.1-8B",
        "basis_rank": rank,
        "n_corpus": 200,
        "layer_range": (8, 24),
        "context_len": 512,
        "l2_norm": 1.234,
        "scs": 0.42,
        "margin_proof": 0.013,
        "task_description": "Test document for transformer attention.",
        "cache_id": "test-doc-001",
    }
    return encoder.encode(**fields)
class TestFormat:
    """Byte-level checks on the encoded certificate, without decoding it."""

    def test_magic_present(self, sample_cert: bytes) -> None:
        """The first four bytes equal ``EIGENGRAM_MAGIC``."""
        leading = sample_cert[:4]
        assert leading == EIGENGRAM_MAGIC

    def test_version_byte(self, sample_cert: bytes) -> None:
        """The unsigned byte at offset 4 equals ``EIGENGRAM_VERSION``."""
        (version,) = struct.unpack_from("<B", sample_cert, 4)
        assert version == EIGENGRAM_VERSION

    def test_minimum_size(self, sample_cert: bytes, basis) -> None:
        """The certificate is at least as long as a rank-dependent floor."""
        rank = basis["basis"].shape[0]
        # NOTE(review): presumably a 99-byte fixed part, two rank-length
        # vectors at 2 bytes per element, and a 128-element vector at 2 bytes
        # per element -- confirm against the format definition.
        floor = 99 + rank * 2 + rank * 2 + 128 * 2
        assert len(sample_cert) >= floor

    def test_file_size_reasonable(self, sample_cert: bytes) -> None:
        """The certificate stays under 2048 bytes."""
        size = len(sample_cert)
        assert size < 2048
class TestRoundTrip:
    """Decode an encoded certificate and compare fields with what was encoded."""

    @staticmethod
    def _decode(cert: bytes):
        """Decode ``cert`` with a fresh ``EigramDecoder`` and return the record."""
        return EigramDecoder().decode(cert)

    def test_model_id(self, sample_cert: bytes) -> None:
        """``model_id`` comes back unchanged."""
        record = self._decode(sample_cert)
        assert record["model_id"] == "Llama-3.1-8B"

    def test_basis_rank(self, sample_cert: bytes, basis) -> None:
        """``basis_rank`` equals the first dimension of the loaded basis."""
        record = self._decode(sample_cert)
        assert record["basis_rank"] == basis["basis"].shape[0]

    def test_vec_perdoc_shape(self, sample_cert: bytes, basis) -> None:
        """``vec_perdoc`` decodes to a 1-D tensor of basis-rank length."""
        record = self._decode(sample_cert)
        expected = (basis["basis"].shape[0],)
        assert record["vec_perdoc"].shape == expected

    def test_vec_fcdb_shape(self, sample_cert: bytes, basis) -> None:
        """``vec_fcdb`` decodes to a 1-D tensor of basis-rank length."""
        record = self._decode(sample_cert)
        expected = (basis["basis"].shape[0],)
        assert record["vec_fcdb"].shape == expected

    def test_joint_center_shape(self, sample_cert: bytes) -> None:
        """``joint_center`` decodes to shape ``(128,)``."""
        record = self._decode(sample_cert)
        assert record["joint_center"].shape == (128,)

    def test_scs(self, sample_cert: bytes) -> None:
        """``scs`` is within 0.01 of the encoded 0.42."""
        record = self._decode(sample_cert)
        error = abs(record["scs"] - 0.42)
        assert error < 0.01

    def test_margin_proof(self, sample_cert: bytes) -> None:
        """``margin_proof`` is within 0.001 of the encoded 0.013."""
        record = self._decode(sample_cert)
        error = abs(record["margin_proof"] - 0.013)
        assert error < 0.001

    def test_task_description(self, sample_cert: bytes) -> None:
        """The decoded description still contains the word "transformer"."""
        record = self._decode(sample_cert)
        assert "transformer" in record["task_description"]

    def test_cache_id(self, sample_cert: bytes) -> None:
        """``cache_id`` comes back unchanged."""
        record = self._decode(sample_cert)
        assert record["cache_id"] == "test-doc-001"

    def test_layer_range(self, sample_cert: bytes) -> None:
        """``layer_range`` comes back as the tuple ``(8, 24)``."""
        record = self._decode(sample_cert)
        assert record["layer_range"] == (8, 24)

    def test_n_corpus(self, sample_cert: bytes) -> None:
        """``n_corpus`` comes back as 200."""
        record = self._decode(sample_cert)
        assert record["n_corpus"] == 200

    def test_context_len(self, sample_cert: bytes) -> None:
        """``context_len`` comes back as 512."""
        record = self._decode(sample_cert)
        assert record["context_len"] == 512

    def test_float16_cosine_preserved(self, basis) -> None:
        """A unit vector survives encode/decode with cosine similarity > 0.999."""
        encoder = EigramEncoder()
        rank = basis["basis"].shape[0]
        unit = torch.randn(rank)
        unit = unit / unit.norm()
        # The same vector is used for both slots; only vec_perdoc is compared.
        cert = encoder.encode(
            vec_perdoc=unit,
            vec_fcdb=unit,
            joint_center=basis["joint_center"],
            corpus_hash="a" * 32,
            model_id="test",
            basis_rank=rank,
            n_corpus=200,
            layer_range=(8, 24),
            context_len=0,
            l2_norm=1.0,
            scs=0.5,
            margin_proof=0.0,
            task_description="cosine test",
            cache_id="cos",
        )
        record = self._decode(cert)
        original_row = unit.unsqueeze(0)
        decoded_row = record["vec_perdoc"].unsqueeze(0)
        cos = torch.nn.functional.cosine_similarity(original_row, decoded_row).item()
        assert cos > 0.999, f"Cosine after round-trip: {cos:.5f}"
class TestErrorHandling:
    """The decoder must reject malformed input by raising."""

    def test_bad_magic_raises(self) -> None:
        """A wrong 4-byte prefix raises ``ValueError`` mentioning "magic"."""
        payload = b"XXXX" + b"\x00" * 200
        with pytest.raises(ValueError, match="magic"):
            EigramDecoder().decode(payload)

    def test_wrong_version_raises(self, sample_cert: bytes) -> None:
        """Overwriting byte 4 with 99 raises ``ValueError`` mentioning "version"."""
        tampered = bytearray(sample_cert)
        tampered[4] = 99  # byte 4 is the one compared with EIGENGRAM_VERSION
        payload = bytes(tampered)
        with pytest.raises(ValueError, match="version"):
            EigramDecoder().decode(payload)

    def test_truncated_raises(self, sample_cert: bytes) -> None:
        """Decoding only the first 20 bytes raises some exception."""
        payload = sample_cert[:20]
        # NOTE(review): Exception is deliberately broad because the decoder's
        # failure type for short input is not visible from this file; narrow
        # it once that type is confirmed.
        with pytest.raises(Exception):
            EigramDecoder().decode(payload)