"""
EIGENGRAM test suite — no model calls, pure format verification.
"""
from __future__ import annotations
import os
import struct
import pytest
import torch
from kvcos.engram.format import (
EigramDecoder,
EigramEncoder,
EIGENGRAM_MAGIC,
EIGENGRAM_VERSION,
)
# Location of the pre-built FCDB v2 corpus basis checkpoint.
BASIS_PATH = "results/corpus_basis_fcdb_v2.pt"


@pytest.fixture(scope="module")
def basis():
    """Load the FCDB v2 basis checkpoint, or skip the whole module if absent."""
    if os.path.exists(BASIS_PATH):
        return torch.load(BASIS_PATH, weights_only=False)
    pytest.skip("FCDB v2 basis not built yet")
@pytest.fixture(scope="module")
def sample_cert(basis):
    """Encode one synthetic certificate with fixed metadata for all tests."""
    encoder = EigramEncoder()
    rank = basis["basis"].shape[0]
    cert = encoder.encode(
        vec_perdoc=torch.randn(rank),
        vec_fcdb=torch.randn(rank),
        joint_center=basis["joint_center"],
        corpus_hash="a" * 32,
        model_id="Llama-3.1-8B",
        basis_rank=rank,
        n_corpus=200,
        layer_range=(8, 24),
        context_len=512,
        l2_norm=1.234,
        scs=0.42,
        margin_proof=0.013,
        task_description="Test document for transformer attention.",
        cache_id="test-doc-001",
    )
    return cert
class TestFormat:
    """Structural checks on the raw certificate bytes (no decoding)."""

    def test_magic_present(self, sample_cert: bytes) -> None:
        magic = sample_cert[:4]
        assert magic == EIGENGRAM_MAGIC

    def test_version_byte(self, sample_cert: bytes) -> None:
        (version,) = struct.unpack_from("<B", sample_cert, 4)
        assert version == EIGENGRAM_VERSION

    def test_minimum_size(self, sample_cert: bytes, basis) -> None:
        rank = basis["basis"].shape[0]
        # 99-byte fixed header + two fp16 vectors of length `rank`
        # + a 128-dim fp16 joint center.
        floor = 99 + 2 * rank + 2 * rank + 2 * 128
        assert len(sample_cert) >= floor

    def test_file_size_reasonable(self, sample_cert: bytes) -> None:
        # A certificate should stay compact — well under 2 KiB.
        assert len(sample_cert) < 2048
class TestRoundTrip:
    """Encode → decode round-trip: every field must survive intact.

    All twelve field tests previously re-decoded the certificate inline;
    the ``rec`` fixture performs that decode once per test, keeping a
    single decode path to maintain.
    """

    @pytest.fixture()
    def rec(self, sample_cert: bytes):
        """Decoded record of the shared sample certificate."""
        return EigramDecoder().decode(sample_cert)

    def test_model_id(self, rec) -> None:
        assert rec["model_id"] == "Llama-3.1-8B"

    def test_basis_rank(self, rec, basis) -> None:
        assert rec["basis_rank"] == basis["basis"].shape[0]

    def test_vec_perdoc_shape(self, rec, basis) -> None:
        assert rec["vec_perdoc"].shape == (basis["basis"].shape[0],)

    def test_vec_fcdb_shape(self, rec, basis) -> None:
        assert rec["vec_fcdb"].shape == (basis["basis"].shape[0],)

    def test_joint_center_shape(self, rec) -> None:
        assert rec["joint_center"].shape == (128,)

    def test_scs(self, rec) -> None:
        # Stored as fp16, so allow quantization error.
        assert abs(rec["scs"] - 0.42) < 0.01

    def test_margin_proof(self, rec) -> None:
        assert abs(rec["margin_proof"] - 0.013) < 0.001

    def test_task_description(self, rec) -> None:
        assert "transformer" in rec["task_description"]

    def test_cache_id(self, rec) -> None:
        assert rec["cache_id"] == "test-doc-001"

    def test_layer_range(self, rec) -> None:
        assert rec["layer_range"] == (8, 24)

    def test_n_corpus(self, rec) -> None:
        assert rec["n_corpus"] == 200

    def test_context_len(self, rec) -> None:
        assert rec["context_len"] == 512

    def test_float16_cosine_preserved(self, basis) -> None:
        """fp16 quantization must not meaningfully rotate a unit vector."""
        enc = EigramEncoder()
        R = basis["basis"].shape[0]
        v = torch.randn(R)
        v = v / v.norm()
        cert = enc.encode(
            vec_perdoc=v, vec_fcdb=v,
            joint_center=basis["joint_center"],
            corpus_hash="a" * 32, model_id="test",
            basis_rank=R, n_corpus=200,
            layer_range=(8, 24), context_len=0,
            l2_norm=1.0, scs=0.5, margin_proof=0.0,
            task_description="cosine test", cache_id="cos",
        )
        rec = EigramDecoder().decode(cert)
        cos = torch.nn.functional.cosine_similarity(
            v.unsqueeze(0), rec["vec_perdoc"].unsqueeze(0)
        ).item()
        assert cos > 0.999, f"Cosine after round-trip: {cos:.5f}"
class TestErrorHandling:
    """Malformed inputs must be rejected with clear errors."""

    def test_bad_magic_raises(self) -> None:
        payload = b"XXXX" + b"\x00" * 200
        with pytest.raises(ValueError, match="magic"):
            EigramDecoder().decode(payload)

    def test_wrong_version_raises(self, sample_cert: bytes) -> None:
        # Corrupt only the version byte of an otherwise valid certificate.
        corrupted = bytearray(sample_cert)
        corrupted[4] = 99
        with pytest.raises(ValueError, match="version"):
            EigramDecoder().decode(bytes(corrupted))

    def test_truncated_raises(self, sample_cert: bytes) -> None:
        truncated = sample_cert[:20]
        with pytest.raises(Exception):
            EigramDecoder().decode(truncated)