File size: 5,083 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
EIGENGRAM test suite — no model calls, pure format verification.
"""

from __future__ import annotations

import os
import struct

import pytest
import torch

from kvcos.engram.format import (
    EigramDecoder,
    EigramEncoder,
    EIGENGRAM_MAGIC,
    EIGENGRAM_VERSION,
)

# Location of the serialized FCDB v2 corpus basis read by the `basis` fixture.
# The path is relative, so it resolves against the current working directory;
# when the file is absent the fixture skips the tests that depend on it.
BASIS_PATH = "results/corpus_basis_fcdb_v2.pt"


@pytest.fixture(scope="module")
def basis():
    """Load the FCDB v2 corpus basis once per test module.

    Every test that depends on this fixture is skipped when no file exists
    at ``BASIS_PATH``.

    Returns:
        The object stored in the basis file. The tests in this module index
        it with the ``"basis"`` and ``"joint_center"`` keys.
    """
    if not os.path.exists(BASIS_PATH):
        pytest.skip("FCDB v2 basis not built yet")
    # map_location="cpu": a basis saved from a GPU process would otherwise
    # fail to deserialize on a CPU-only machine; these tests are pure format
    # checks and never need an accelerator.
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # acceptable because the file is a locally built artifact; never point
    # BASIS_PATH at untrusted data.
    return torch.load(BASIS_PATH, map_location="cpu", weights_only=False)


@pytest.fixture(scope="module")
def sample_cert(basis):
    """Encode one reference certificate shared by the tests in this module.

    The two projection vectors are random draws whose length equals the
    first dimension of ``basis["basis"]``; all other fields are fixed
    literals that the round-trip tests compare against.

    Returns:
        Whatever ``EigramEncoder.encode`` produces (the tests treat it as
        ``bytes``).
    """
    encoder = EigramEncoder()
    rank = basis["basis"].shape[0]
    # Insertion order matches the keyword order of the encode call so the
    # random draws happen in the same sequence (perdoc first, then fcdb).
    fields = {
        "vec_perdoc": torch.randn(rank),
        "vec_fcdb": torch.randn(rank),
        "joint_center": basis["joint_center"],
        "corpus_hash": "a" * 32,
        "model_id": "Llama-3.1-8B",
        "basis_rank": rank,
        "n_corpus": 200,
        "layer_range": (8, 24),
        "context_len": 512,
        "l2_norm": 1.234,
        "scs": 0.42,
        "margin_proof": 0.013,
        "task_description": "Test document for transformer attention.",
        "cache_id": "test-doc-001",
    }
    return encoder.encode(**fields)


class TestFormat:
    """Byte-level checks on the encoded sample certificate."""

    def test_magic_present(self, sample_cert: bytes) -> None:
        """The first four bytes equal the format magic."""
        leading = sample_cert[:4]
        assert leading == EIGENGRAM_MAGIC

    def test_version_byte(self, sample_cert: bytes) -> None:
        """The unsigned byte at offset 4 equals the format version."""
        (version,) = struct.unpack_from("<B", sample_cert, 4)
        assert version == EIGENGRAM_VERSION

    def test_minimum_size(self, sample_cert: bytes, basis) -> None:
        """The certificate is at least as long as the computed lower bound."""
        rank = basis["basis"].shape[0]
        # NOTE(review): presumably 99 fixed header bytes, two rank-length
        # vectors and a 128-element joint center at 2 bytes per element —
        # confirm against the format specification.
        lower_bound = 99 + 2 * rank + 2 * rank + 2 * 128
        assert len(sample_cert) >= lower_bound

    def test_file_size_reasonable(self, sample_cert: bytes) -> None:
        """The certificate stays below 2048 bytes."""
        size = len(sample_cert)
        assert size < 2048


class TestRoundTrip:
    """Encode-then-decode checks: decoded fields match what was encoded."""

    @staticmethod
    def _decoded(cert: bytes):
        """Decode *cert* with a fresh ``EigramDecoder`` and return the record."""
        return EigramDecoder().decode(cert)

    def test_model_id(self, sample_cert: bytes) -> None:
        """The model identifier survives the round trip unchanged."""
        record = self._decoded(sample_cert)
        assert record["model_id"] == "Llama-3.1-8B"

    def test_basis_rank(self, sample_cert: bytes, basis) -> None:
        """The decoded rank equals the first dimension of the basis."""
        record = self._decoded(sample_cert)
        expected_rank = basis["basis"].shape[0]
        assert record["basis_rank"] == expected_rank

    def test_vec_perdoc_shape(self, sample_cert: bytes, basis) -> None:
        """The per-document vector is one-dimensional with rank entries."""
        record = self._decoded(sample_cert)
        expected_shape = (basis["basis"].shape[0],)
        assert record["vec_perdoc"].shape == expected_shape

    def test_vec_fcdb_shape(self, sample_cert: bytes, basis) -> None:
        """The FCDB vector is one-dimensional with rank entries."""
        record = self._decoded(sample_cert)
        expected_shape = (basis["basis"].shape[0],)
        assert record["vec_fcdb"].shape == expected_shape

    def test_joint_center_shape(self, sample_cert: bytes) -> None:
        """The decoded joint center has shape ``(128,)``."""
        record = self._decoded(sample_cert)
        assert record["joint_center"].shape == (128,)

    def test_scs(self, sample_cert: bytes) -> None:
        """The decoded ``scs`` is within 0.01 of the encoded 0.42."""
        record = self._decoded(sample_cert)
        deviation = abs(record["scs"] - 0.42)
        assert deviation < 0.01

    def test_margin_proof(self, sample_cert: bytes) -> None:
        """The decoded ``margin_proof`` is within 0.001 of the encoded 0.013."""
        record = self._decoded(sample_cert)
        deviation = abs(record["margin_proof"] - 0.013)
        assert deviation < 0.001

    def test_task_description(self, sample_cert: bytes) -> None:
        """The decoded description still contains the word "transformer"."""
        record = self._decoded(sample_cert)
        description = record["task_description"]
        assert "transformer" in description

    def test_cache_id(self, sample_cert: bytes) -> None:
        """The cache identifier survives the round trip unchanged."""
        record = self._decoded(sample_cert)
        assert record["cache_id"] == "test-doc-001"

    def test_layer_range(self, sample_cert: bytes) -> None:
        """The layer range decodes to the tuple that was encoded."""
        record = self._decoded(sample_cert)
        assert record["layer_range"] == (8, 24)

    def test_n_corpus(self, sample_cert: bytes) -> None:
        """The corpus size survives the round trip unchanged."""
        record = self._decoded(sample_cert)
        assert record["n_corpus"] == 200

    def test_context_len(self, sample_cert: bytes) -> None:
        """The context length survives the round trip unchanged."""
        record = self._decoded(sample_cert)
        assert record["context_len"] == 512

    def test_float16_cosine_preserved(self, basis) -> None:
        """A unit vector keeps cosine similarity above 0.999 after decoding."""
        encoder = EigramEncoder()
        rank = basis["basis"].shape[0]
        raw = torch.randn(rank)
        unit = raw / raw.norm()
        fields = {
            "vec_perdoc": unit,
            "vec_fcdb": unit,
            "joint_center": basis["joint_center"],
            "corpus_hash": "a" * 32,
            "model_id": "test",
            "basis_rank": rank,
            "n_corpus": 200,
            "layer_range": (8, 24),
            "context_len": 0,
            "l2_norm": 1.0,
            "scs": 0.5,
            "margin_proof": 0.0,
            "task_description": "cosine test",
            "cache_id": "cos",
        }
        cert = encoder.encode(**fields)
        record = self._decoded(cert)
        # cosine_similarity works along dim 1 by default, so both vectors
        # are lifted to a single-row batch before comparison.
        original_row = unit.unsqueeze(0)
        decoded_row = record["vec_perdoc"].unsqueeze(0)
        similarity = torch.nn.functional.cosine_similarity(
            original_row, decoded_row
        )
        cos = similarity.item()
        assert cos > 0.999, f"Cosine after round-trip: {cos:.5f}"


class TestErrorHandling:
    """The decoder rejects malformed input by raising."""

    def test_bad_magic_raises(self) -> None:
        """A buffer with the wrong leading bytes raises ValueError("…magic…")."""
        payload = b"XXXX" + bytes(200)
        with pytest.raises(ValueError, match="magic"):
            EigramDecoder().decode(payload)

    def test_wrong_version_raises(self, sample_cert: bytes) -> None:
        """Overwriting byte 4 with 99 raises ValueError("…version…")."""
        tampered = bytearray(sample_cert)
        tampered[4] = 99
        payload = bytes(tampered)
        with pytest.raises(ValueError, match="version"):
            EigramDecoder().decode(payload)

    def test_truncated_raises(self, sample_cert: bytes) -> None:
        """Decoding only the first 20 bytes raises some exception."""
        truncated = sample_cert[:20]
        # NOTE(review): Exception is deliberately kept as broad as before;
        # narrow it once the decoder's truncation error type is confirmed.
        with pytest.raises(Exception):
            EigramDecoder().decode(truncated)