eigengram committed on
Commit
2ece486
·
verified ·
1 Parent(s): 19d71eb

test: upload 220 tests

Browse files
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Test Fixtures
3
+
4
+
5
+ Shared pytest fixtures for all test modules.
6
+ Provides synthetic KV cache tensors at correct shapes,
7
+ temp directories, and model specs.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from pathlib import Path
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ from kvcos.core.cache_spec import GEMMA_4_26B_A4B, LLAMA_3_1_8B, PHI_3_MINI
18
+ from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
19
+
20
+
21
@pytest.fixture
def llama_spec() -> ModelCacheSpec:
    """Model spec fixture for Llama 3.1 8B."""
    return LLAMA_3_1_8B


@pytest.fixture
def phi3_spec() -> ModelCacheSpec:
    """Model spec fixture for Phi-3-Mini."""
    return PHI_3_MINI


@pytest.fixture
def gemma4_spec() -> ModelCacheSpec:
    """Model spec fixture for Gemma 4 26B-A4B (ISWA)."""
    return GEMMA_4_26B_A4B
37
+
38
+
39
@pytest.fixture
def tmp_data_dir(tmp_path: Path) -> Path:
    """Fresh temporary data directory for storage tests."""
    path = tmp_path / "engram_data"
    path.mkdir()
    return path


@pytest.fixture
def tmp_index_dir(tmp_path: Path) -> Path:
    """Fresh temporary directory for FAISS index persistence tests."""
    path = tmp_path / "engram_index"
    path.mkdir()
    return path
53
+
54
+
55
def make_synthetic_kv(
    spec: ModelCacheSpec,
    ctx_len: int = 256,
    seed: int = 42,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create synthetic KV cache tensors with correct shapes.

    Returns ``(keys, values)``, each of shape
    ``[n_layers, n_kv_heads, ctx_len, head_dim]`` in fp16.
    Data is random but reproducible for a given ``seed``.
    """
    torch.manual_seed(seed)
    kv_shape = (
        spec["n_layers"],
        spec["n_kv_heads"],
        ctx_len,
        spec["head_dim"],
    )
    # Keys are drawn first, then values, so draw order matches the seed.
    keys = torch.randn(kv_shape, dtype=torch.float16)
    values = torch.randn(kv_shape, dtype=torch.float16)
    return keys, values
70
+
71
+
72
@pytest.fixture
def llama_kv_256(llama_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """Synthetic Llama 3.1 8B KV cache, 256 tokens.

    Keys and values are both shaped [32, 8, 256, 128].
    """
    return make_synthetic_kv(llama_spec, ctx_len=256)


@pytest.fixture
def llama_kv_1024(llama_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """Synthetic Llama 3.1 8B KV cache, 1024 tokens."""
    return make_synthetic_kv(llama_spec, ctx_len=1024, seed=123)


@pytest.fixture
def phi3_kv_256(phi3_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """Synthetic Phi-3-Mini KV cache, 256 tokens.

    Keys and values are both shaped [32, 32, 256, 96].
    """
    return make_synthetic_kv(phi3_spec, ctx_len=256, seed=99)
94
+
95
+
96
+ # ── ISWA Fixtures ────────────────────────────────────────────────────────────
97
+
98
+
99
def make_synthetic_iswa_blob(
    sections: tuple[CacheSection, ...],
    n_cells: int = 4,
    arch: str = "gemma4",
    v_trans: bool = True,
    seed: int = 42,
) -> bytes:
    """Build a synthetic ISWA blob with multiple KV cache sections.

    Matches llama.cpp state blob format for ISWA models:
    1. Architecture string header
    2. n_stream = len(sections)
    3. Per stream: cell metadata + K/V data per layer

    Args:
        sections: Cache sections (e.g., global + SWA for Gemma 4).
        n_cells: Number of KV cells per section.
        arch: Architecture string in blob header.
        v_trans: Whether V tensors are stored transposed.
        seed: Random seed for reproducible data.
    """
    import struct

    import numpy as np

    from kvcos.core.blob_parser import GGML_TYPE_F16

    rng = np.random.RandomState(seed)
    out: list[bytes] = []

    # Header: length-prefixed ASCII architecture string.
    out.append(struct.pack("<I", len(arch)))
    out.append(arch.encode("ascii"))

    # One KV stream per cache section.
    out.append(struct.pack("<I", len(sections)))

    for section in sections:
        n_embd_kv = section.n_kv_heads * section.head_dim
        row_size = n_embd_kv * 2  # fp16 → 2 bytes per element

        # Cell metadata: count, then (pos, n_seq_id=1, seq_id=0) per cell.
        # "<iIi" with little-endian packing has no padding, so the bytes
        # match three separate packs.
        out.append(struct.pack("<I", n_cells))
        for pos in range(n_cells):
            out.append(struct.pack("<iIi", pos, 1, 0))

        # Data section header: v_trans flag followed by layer count.
        out.append(struct.pack("<II", 1 if v_trans else 0, section.n_layers))

        # K tensors, one per layer: type tag, row size, raw fp16 data.
        for _ in range(section.n_layers):
            out.append(struct.pack("<iQ", GGML_TYPE_F16, row_size))
            out.append(rng.randn(n_cells * n_embd_kv).astype(np.float16).tobytes())

        # V tensors, one per layer; header differs for transposed storage.
        for _ in range(section.n_layers):
            out.append(struct.pack("<i", GGML_TYPE_F16))
            if v_trans:
                out.append(struct.pack("<II", 2, n_embd_kv))  # el_size, n_embd_v_gqa
            else:
                out.append(struct.pack("<Q", row_size))
            out.append(rng.randn(n_cells * n_embd_kv).astype(np.float16).tobytes())

    return b"".join(out)
171
+
172
+
173
+ # Gemma 4 ISWA section constants (reverse-engineered)
174
# Gemma 4 ISWA section constants (reverse-engineered).
# Global-attention layers: few layers, wide heads.
GEMMA4_GLOBAL_SECTION = CacheSection(
    attention_type=AttentionType.FULL,
    n_layers=5,
    n_kv_heads=2,
    head_dim=512,
)

# Sliding-window layers: the bulk of the stack, 1024-token window.
GEMMA4_SWA_SECTION = CacheSection(
    attention_type=AttentionType.SLIDING,
    n_layers=25,
    n_kv_heads=8,
    head_dim=256,
    window_size=1024,
)

GEMMA4_SECTIONS = (GEMMA4_GLOBAL_SECTION, GEMMA4_SWA_SECTION)
190
+
191
+
192
@pytest.fixture
def gemma4_iswa_blob() -> bytes:
    """Synthetic Gemma 4 ISWA blob with 2 sections, 4 cells."""
    return make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)


@pytest.fixture
def gemma4_iswa_blob_8cells() -> bytes:
    """Synthetic Gemma 4 ISWA blob with 2 sections, 8 cells."""
    return make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=8, seed=99)
tests/test_blob_parser.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Blob Parser Tests
3
+ Tests for llama.cpp state blob → structured tensors (D1).
4
+ Uses synthetic blobs matching the real llama_state_get_data() format.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import struct
10
+
11
+ import numpy as np
12
+ import pytest
13
+ import torch
14
+
15
+ from kvcos.core.blob_parser import (
16
+ GGML_TYPE_F16,
17
+ BlobParseError,
18
+ ParsedKVCache,
19
+ parse_state_blob,
20
+ )
21
+
22
+
23
def _make_blob(
    n_cells: int,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    arch: str = "llama",
    v_trans: bool = True,
    seed: int = 0,
) -> bytes:
    """Build a synthetic blob matching llama_state_get_data() format.

    Args:
        n_cells: Number of KV cells in the stream.
        n_layers: Number of transformer layers (K and V tensors each).
        n_kv_heads: KV heads per layer.
        head_dim: Per-head dimension.
        arch: Architecture string written to the blob header.
        v_trans: Whether V tensors are stored in transposed layout.
        seed: RNG seed so blob contents (and any test failure) are
            reproducible across runs. Previously `np.random.randn` was
            unseeded, making failures non-deterministic.
    """
    rng = np.random.RandomState(seed)
    parts: list[bytes] = []

    # 1. Architecture string header (length-prefixed ASCII)
    parts.append(struct.pack("<I", len(arch)))
    parts.append(arch.encode("ascii"))

    # 2. KV stream header
    parts.append(struct.pack("<I", 1))  # n_stream = 1
    parts.append(struct.pack("<I", n_cells))  # cell_count

    # 3. Cell metadata: (pos:i32, n_seq:u32, seq_id:i32) per cell
    for i in range(n_cells):
        parts.append(struct.pack("<i", i))  # pos
        parts.append(struct.pack("<I", 1))  # n_seq_id = 1
        parts.append(struct.pack("<i", 0))  # seq_id = 0

    # 4. Data section header
    parts.append(struct.pack("<I", 1 if v_trans else 0))  # v_trans
    parts.append(struct.pack("<I", n_layers))

    n_embd_kv = n_kv_heads * head_dim
    row_size = n_embd_kv * 2  # fp16

    # 5. K layers: type tag, row size, then raw fp16 data
    for _ in range(n_layers):
        parts.append(struct.pack("<i", GGML_TYPE_F16))  # type_k
        parts.append(struct.pack("<Q", row_size))  # row_size_k
        data = rng.randn(n_cells * n_embd_kv).astype(np.float16)
        parts.append(data.tobytes())

    # 6. V layers: transposed layout writes (el_size, n_embd_v_gqa) instead
    #    of a row size
    for _ in range(n_layers):
        parts.append(struct.pack("<i", GGML_TYPE_F16))  # type_v
        if v_trans:
            parts.append(struct.pack("<I", 2))  # el_size (fp16)
            parts.append(struct.pack("<I", n_embd_kv))  # n_embd_v_gqa
        else:
            parts.append(struct.pack("<Q", row_size))  # row_size_v
        data = rng.randn(n_cells * n_embd_kv).astype(np.float16)
        parts.append(data.tobytes())

    return b"".join(parts)
74
+
75
+
76
class TestBlobParser:
    """Parse synthetic blobs in real llama_state_get_data format."""

    def test_parse_shape(self) -> None:
        parsed = parse_state_blob(_make_blob(16, 32, 8, 128), n_kv_heads=8, head_dim=128)
        assert parsed.keys.shape == (32, 8, 16, 128)
        assert parsed.values.shape == (32, 8, 16, 128)

    def test_parse_metadata(self) -> None:
        parsed = parse_state_blob(_make_blob(8, 32, 8, 128), n_kv_heads=8, head_dim=128)
        assert parsed.n_cells == 8
        assert parsed.n_layers == 32
        assert parsed.arch == "llama"
        assert parsed.v_trans is True
        assert len(parsed.cells) == 8
        # Cell positions are sequential from zero.
        assert parsed.cells[0].pos == 0
        assert parsed.cells[7].pos == 7

    def test_dtype_float16(self) -> None:
        parsed = parse_state_blob(_make_blob(4, 28, 8, 128), n_kv_heads=8, head_dim=128)
        assert parsed.keys.dtype == torch.float16
        assert parsed.values.dtype == torch.float16

    def test_non_transposed_v(self) -> None:
        blob = _make_blob(4, 28, 8, 128, v_trans=False)
        parsed = parse_state_blob(blob, n_kv_heads=8, head_dim=128)
        assert parsed.values.shape == (28, 8, 4, 128)
        assert parsed.v_trans is False
107
+
108
+
109
class TestBlobParserErrors:
    """Edge cases."""

    def test_zero_cells_raises(self) -> None:
        # Valid header, one stream, but cell_count == 0.
        blob = struct.pack("<I", 5) + b"llama" + struct.pack("<II", 1, 0) + b"\x00" * 20
        with pytest.raises(BlobParseError, match="0 cells"):
            parse_state_blob(blob, n_kv_heads=8, head_dim=128)

    def test_truncated_blob_raises(self) -> None:
        full = _make_blob(4, 28, 8, 128)
        with pytest.raises(BlobParseError):
            parse_state_blob(full[:100], n_kv_heads=8, head_dim=128)

    def test_bad_arch_length_raises(self) -> None:
        # Declared arch-string length far exceeds the payload.
        blob = struct.pack("<I", 999) + b"x" * 100
        with pytest.raises(BlobParseError, match="too large"):
            parse_state_blob(blob, n_kv_heads=8, head_dim=128)
tests/test_block_pool.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Block Pool Tests
3
+ Tests for 256-token block segmentation/assembly/extend.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import pytest
9
+ import torch
10
+
11
+ from kvcos.core.block_pool import BlockPool, KVBlock
12
+ from kvcos.core.types import BLOCK_SIZE_TOKENS
13
+
14
+
15
+ def _kv(n_layers: int, n_heads: int, ctx: int, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
16
+ k = torch.randn(n_layers, n_heads, ctx, dim, dtype=torch.float16)
17
+ return k, k.clone()
18
+
19
+
20
class TestSegment:
    """Segment a full KV cache into 256-token blocks."""

    def test_exact_blocks(self) -> None:
        k, v = _kv(32, 8, 512, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        segmented = pool.segment(k, v)
        assert len(segmented) == 2
        assert all(blk.is_full for blk in segmented)

    def test_partial_last_block(self) -> None:
        # 300 tokens = one full 256-token block + a 44-token remainder.
        k, v = _kv(32, 8, 300, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        segmented = pool.segment(k, v)
        assert len(segmented) == 2
        assert segmented[0].is_full
        assert not segmented[1].is_full
        assert segmented[1].block_len == 44

    def test_total_tokens(self) -> None:
        k, v = _kv(32, 8, 700, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert pool.total_tokens == 700
44
+
45
+
46
class TestAssemble:
    """Assemble blocks back into a full KV cache."""

    def test_round_trip(self) -> None:
        k, v = _kv(4, 2, 512, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assembled_k, assembled_v = pool.assemble()
        assert torch.equal(assembled_k, k)

    def test_subset_assembly(self) -> None:
        k, v = _kv(4, 2, 768, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        # Picking two of the three blocks yields exactly two blocks' tokens.
        assembled_k, _ = pool.assemble(block_indices=[0, 2])
        assert assembled_k.shape[2] == BLOCK_SIZE_TOKENS * 2

    def test_empty_raises(self) -> None:
        pool = BlockPool(agent_id="a", model_id="m")
        with pytest.raises(ValueError, match="No blocks"):
            pool.assemble()
67
+
68
+
69
class TestExtend:
    """Extend pool with new tokens."""

    def test_fills_partial_block(self) -> None:
        # 200 tokens leaves a partial block; adding 56 completes it to 256.
        k, v = _kv(4, 2, 200, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert not pool.blocks[-1].is_full

        extra_k, extra_v = _kv(4, 2, 56, 64)
        pool.extend(extra_k, extra_v)
        assert pool.blocks[-1].is_full
        assert pool.total_tokens == 256

    def test_extend_creates_new_blocks(self) -> None:
        k, v = _kv(4, 2, 256, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert pool.n_blocks == 1

        # 300 more tokens spill over into two additional blocks.
        extra_k, extra_v = _kv(4, 2, 300, 64)
        pool.extend(extra_k, extra_v)
        assert pool.n_blocks == 3
        assert pool.total_tokens == 556
tests/test_chunker.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for kvcos.engram.chunker — markdown-aware semantic chunker."""
2
+
3
+ import pytest
4
+
5
+ from kvcos.engram.chunker import Chunk, chunk_markdown, eng_filename, slug_from_path
6
+
7
+
8
class TestChunkMarkdown:
    def test_empty_content(self):
        assert chunk_markdown("") == []
        assert chunk_markdown("   ") == []

    def test_small_file_single_chunk(self):
        content = "# Title\n\nSome short content."
        chunks = chunk_markdown(content, max_chars=2000)
        assert len(chunks) == 1
        first = chunks[0]
        assert first.index == 0
        assert first.char_start == 0
        assert first.char_end == len(content)

    def test_large_file_splits(self):
        # Two sections whose combined length exceeds max_chars.
        content = "# Section 1\n\n" + "A" * 1500 + "\n\n# Section 2\n\n" + "B" * 1500
        assert len(chunk_markdown(content, max_chars=2000)) >= 2

    def test_chunks_cover_full_content(self):
        content = "# A\n\nText A.\n\n# B\n\nText B.\n\n# C\n\nText C."
        chunks = chunk_markdown(content, max_chars=15)
        # Every section's body must survive somewhere in the output.
        combined = " ".join(c.raw_text for c in chunks)
        assert all(word in combined for word in ("Text A", "Text B", "Text C"))

    def test_context_prefix(self):
        chunks = chunk_markdown("Hello world", context_prefix="Source: test.md")
        assert len(chunks) == 1
        assert chunks[0].text.startswith("Source: test.md")

    def test_indices_sequential(self):
        content = "# A\n\n" + "X" * 3000 + "\n\n# B\n\n" + "Y" * 3000
        for expected, chunk in enumerate(chunk_markdown(content, max_chars=2000)):
            assert chunk.index == expected

    def test_merge_small_sections(self):
        """Small consecutive sections should merge into one chunk."""
        content = "# A\n\nShort.\n\n# B\n\nAlso short.\n\n# C\n\nStill short."
        chunks = chunk_markdown(content, max_chars=2000, min_chars=100)
        # All three small sections should merge into 1 chunk
        assert len(chunks) == 1

    def test_paragraph_split_fallback(self):
        """Content without headers should split on paragraphs."""
        content = "\n\n".join(
            f"Paragraph {i}. " + "X" * 500 for i in range(6)
        )
        assert len(chunk_markdown(content, max_chars=1500)) >= 2
61
+
62
+
63
class TestSlugFromPath:
    def test_simple_filename(self):
        assert slug_from_path("readme.md") == "readme"

    def test_uppercase_underscores(self):
        # Uppercase and underscores normalize to lowercase kebab-case.
        assert slug_from_path("EIGENGRAM_SPEC.md") == "eigengram-spec"

    def test_already_kebab(self):
        assert slug_from_path("coding-style.md") == "coding-style"

    def test_full_path(self):
        # Directory components are dropped; only the stem is slugged.
        assert slug_from_path("/Users/test/docs/my_doc.md") == "my-doc"

    def test_special_chars(self):
        assert slug_from_path("file (copy).md") == "file-copy"
78
+
79
+
80
class TestEngFilename:
    def test_single_chunk(self):
        assert eng_filename("engram", "readme", "2026-04-02") == "readme_2026-04-02.eng"

    def test_multi_chunk(self):
        # Chunk index is 1-based and zero-padded to three digits.
        result = eng_filename(
            "engram", "geodesic3", "2026-04-02", chunk_index=0, chunk_total=5
        )
        assert result == "geodesic3_001_2026-04-02.eng"

    def test_with_time(self):
        result = eng_filename("engram", "session", "2026-04-02", time_str="1430")
        assert result == "session_2026-04-02_1430.eng"

    def test_single_chunk_no_index(self):
        """Single-chunk files should not have chunk number."""
        result = eng_filename(
            "engram", "small", "2026-04-02", chunk_index=0, chunk_total=1
        )
        assert result == "small_2026-04-02.eng"
tests/test_compression.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Compression Tests
3
+
4
+
5
+ Tests for kvcos.core.compression:
6
+ - FP16 passthrough
7
+ - Q8_0 round-trip accuracy & shape preservation
8
+ - PolarQuant round-trip accuracy & rotation invariants
9
+ - Dispatcher routing and Q4_0 fallback warning
10
+ - Edge cases: padding, single-element groups
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import warnings
16
+
17
+ import pytest
18
+ import torch
19
+
20
+ from kvcos.core.compression import (
21
+ Q8_GROUP_SIZE,
22
+ CompressionResult,
23
+ compress,
24
+ compress_fp16,
25
+ compress_polarquant,
26
+ compress_q8_0,
27
+ decompress,
28
+ decompress_fp16,
29
+ decompress_polarquant,
30
+ decompress_q8_0,
31
+ )
32
+ from kvcos.core.types import CompressionMethod
33
+
34
+
35
+ # ── FP16 Passthrough ──────────────────────────────────────────────────────────
36
+
37
+
38
class TestFP16:
    """FP16 passthrough: no quantization, just dtype normalization."""

    def test_fp16_passthrough_shape(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).data.shape == keys.shape

    def test_fp16_passthrough_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).data.dtype == torch.float16

    def test_fp16_passthrough_exact(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert torch.equal(compress_fp16(keys).data, keys.to(torch.float16))

    def test_fp16_compression_ratio_one(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).compression_ratio == 1.0

    def test_fp16_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).method == CompressionMethod.FP16

    def test_fp16_from_fp32(self) -> None:
        """FP32 input is cast to FP16."""
        src = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        res = compress_fp16(src)
        assert res.data.dtype == torch.float16
        assert res.original_dtype == torch.float32

    def test_fp16_decompress_identity(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        res = compress_fp16(keys)
        # Decompression of FP16 is the identity mapping.
        assert torch.equal(decompress_fp16(res.data), res.data)
90
+
91
+
92
+ # ── Q8_0 Quantization ────────────────────────────────────────────────────────
93
+
94
+
95
class TestQ8_0:
    """Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0."""

    def test_q8_0_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_q8_0(keys).data.shape == keys.shape

    def test_q8_0_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 stores dequantized bfloat16 for safetensors compat."""
        keys = llama_kv_256[0]
        assert compress_q8_0(keys).data.dtype == torch.bfloat16

    def test_q8_0_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_q8_0(keys).method == CompressionMethod.Q8_0

    def test_q8_0_metadata_group_size(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        res = compress_q8_0(keys)
        assert res.metadata["q8_group_size"] == str(Q8_GROUP_SIZE)

    def test_q8_0_round_trip_low_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 quantization error should be < 1% relative MSE."""
        keys = llama_kv_256[0]
        restored = decompress_q8_0(compress_q8_0(keys).data).float()
        original = keys.float()

        mse = ((original - restored) ** 2).mean()
        signal_power = (original**2).mean()
        relative_mse = (mse / signal_power).item()
        assert relative_mse < 0.01, f"Q8_0 relative MSE {relative_mse:.6f} > 1%"

    def test_q8_0_round_trip_values(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 round-trip on Phi-3 (head_dim=96, needs padding)."""
        for tensor in phi3_kv_256:
            assert compress_q8_0(tensor).data.shape == tensor.shape

    def test_q8_0_compression_ratio_fp32(self) -> None:
        """FP32 input → bfloat16 output gives 2x compression ratio."""
        src = torch.randn(2, 4, 64, 128, dtype=torch.float32)
        assert abs(compress_q8_0(src).compression_ratio - 2.0) < 0.01

    def test_q8_0_compression_ratio_fp16(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """FP16 input → bfloat16 output gives 1x ratio (same byte width)."""
        keys = llama_kv_256[0]
        assert abs(compress_q8_0(keys).compression_ratio - 1.0) < 0.01

    def test_q8_0_preserves_original_dtype(self) -> None:
        src = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        assert compress_q8_0(src).original_dtype == torch.float32

    def test_q8_0_padding_dim_not_divisible(self) -> None:
        """Head dims not divisible by 32 get padded then unpadded."""
        exact = torch.randn(2, 4, 16, 96, dtype=torch.float16)  # 96 = 3*32, exact
        assert compress_q8_0(exact).data.shape == exact.shape

        padded = torch.randn(2, 4, 16, 100, dtype=torch.float16)  # 100 not div by 32
        assert compress_q8_0(padded).data.shape == padded.shape

    def test_q8_0_zero_tensor(self) -> None:
        """All-zero tensor should round-trip exactly."""
        zeros = torch.zeros(2, 4, 16, 128, dtype=torch.float16)
        restored = decompress_q8_0(compress_q8_0(zeros).data)
        assert torch.allclose(restored, zeros.to(torch.float16), atol=1e-6)
187
+
188
+
189
+ # ── PolarQuant ───────────────────────────────────────────────────────────────
190
+
191
+
192
class TestPolarQuant:
    """PolarQuant: MSE-optimal random rotation + Lloyd-Max at 3 bits.
    QJL intentionally absent (D5).
    """

    def test_polarquant_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_polarquant(keys).data.shape == keys.shape

    def test_polarquant_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_polarquant(keys).data.dtype == torch.bfloat16

    def test_polarquant_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_polarquant(keys).method == CompressionMethod.POLARQUANT

    def test_polarquant_metadata_qjl_disabled(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """D5: QJL must be marked disabled in metadata."""
        keys = llama_kv_256[0]
        meta = compress_polarquant(keys).metadata
        assert meta["qjl_enabled"] == "false"
        assert meta["polarquant_bits"] == "3"

    def test_polarquant_round_trip_bounded_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """PolarQuant 3-bit error should be < 15% relative MSE.

        3-bit Lloyd-Max on rotated Gaussian: theoretical ~10% for 8 centroids.
        Allow margin for rotation + dtype casting.
        """
        keys = llama_kv_256[0]
        restored = decompress_polarquant(compress_polarquant(keys).data).float()
        original = keys.float()

        mse = ((original - restored) ** 2).mean()
        signal_power = (original**2).mean()
        relative_mse = (mse / signal_power).item()
        assert relative_mse < 0.15, f"PolarQuant relative MSE {relative_mse:.4f} > 15%"

    def test_polarquant_worse_than_q8_0(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """3-bit PolarQuant should have higher error than 8-bit Q8_0."""
        keys = llama_kv_256[0]
        original = keys.float()

        q8_restored = decompress_q8_0(compress_q8_0(keys).data).float()
        pq_restored = decompress_polarquant(compress_polarquant(keys).data).float()

        q8_mse = ((original - q8_restored) ** 2).mean()
        pq_mse = ((original - pq_restored) ** 2).mean()

        assert pq_mse > q8_mse, "PolarQuant 3-bit should be less accurate than Q8_0"

    def test_polarquant_deterministic(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Same input → same output (fixed seed rotation matrix)."""
        keys = llama_kv_256[0]
        first = compress_polarquant(keys)
        second = compress_polarquant(keys)
        assert torch.equal(first.data, second.data)

    def test_polarquant_phi3_shape(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Phi-3 head_dim=96 works with PolarQuant."""
        keys = phi3_kv_256[0]
        assert compress_polarquant(keys).data.shape == keys.shape
280
+
281
+
282
+ # ── Dispatcher ───────────────────────────────────────────────────────────────
283
+
284
+
285
class TestDispatcher:
    """compress() and decompress() dispatch to correct implementations."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_compress_dispatches(self, method: CompressionMethod) -> None:
        src = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        res = compress(src, method)
        assert isinstance(res, CompressionResult)
        assert res.method == method

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_decompress_returns_fp16(self, method: CompressionMethod) -> None:
        src = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        restored = decompress(compress(src, method).data, method)
        assert restored.dtype == torch.float16

    def test_q4_0_warns_and_falls_back(self) -> None:
        """D5: Q4_0 emits warning and uses Q8_0 instead."""
        src = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            res = compress(src, CompressionMethod.Q4_0)
        assert len(caught) == 1
        message = str(caught[0].message)
        assert "Q4_0" in message
        assert "92%" in message
        assert res.method == CompressionMethod.Q8_0

    def test_unknown_method_raises(self) -> None:
        src = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            compress(src, "invalid_method")  # type: ignore[arg-type]

    def test_decompress_unknown_raises(self) -> None:
        src = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            decompress(src, "invalid_method")  # type: ignore[arg-type]
328
+
329
+
330
+ # ── Round-trip Integration ───────────────────────────────────────────────────
331
+
332
+
333
class TestRoundTrip:
    """Full compress → decompress round-trip through dispatcher."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_round_trip_shape_preserved(self, method: CompressionMethod) -> None:
        src = torch.randn(4, 8, 64, 128, dtype=torch.float16)
        restored = decompress(compress(src, method).data, method)
        assert restored.shape == src.shape

    def test_round_trip_both_kv(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compress and decompress both keys and values."""
        for tensor in llama_kv_256:
            for method in (CompressionMethod.FP16, CompressionMethod.Q8_0):
                restored = decompress(compress(tensor, method).data, method)
                assert restored.shape == tensor.shape
                assert restored.dtype == torch.float16
tests/test_eigengram.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EIGENGRAM test suite — no model calls, pure format verification.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ import struct
9
+
10
+ import pytest
11
+ import torch
12
+
13
+ from kvcos.engram.format import (
14
+ EigramDecoder,
15
+ EigramEncoder,
16
+ EIGENGRAM_MAGIC,
17
+ EIGENGRAM_VERSION,
18
+ )
19
+
20
+ BASIS_PATH = "results/corpus_basis_fcdb_v2.pt"
21
+
22
+
23
@pytest.fixture(scope="module")
def basis():
    """Module-scoped FCDB v2 basis; skip dependent tests when it is not built."""
    if os.path.exists(BASIS_PATH):
        return torch.load(BASIS_PATH, weights_only=False)
    pytest.skip("FCDB v2 basis not built yet")
28
+
29
+
30
@pytest.fixture(scope="module")
def sample_cert(basis):
    """One encoded certificate built from random vectors over the loaded basis."""
    rank = basis["basis"].shape[0]
    encoder = EigramEncoder()
    return encoder.encode(
        vec_perdoc=torch.randn(rank),
        vec_fcdb=torch.randn(rank),
        joint_center=basis["joint_center"],
        corpus_hash="a" * 32,
        model_id="Llama-3.1-8B",
        basis_rank=rank,
        n_corpus=200,
        layer_range=(8, 24),
        context_len=512,
        l2_norm=1.234,
        scs=0.42,
        margin_proof=0.013,
        task_description="Test document for transformer attention.",
        cache_id="test-doc-001",
    )
50
+
51
+
52
class TestFormat:
    """Byte-level layout checks on an encoded certificate."""

    def test_magic_present(self, sample_cert: bytes) -> None:
        header = sample_cert[:4]
        assert header == EIGENGRAM_MAGIC

    def test_version_byte(self, sample_cert: bytes) -> None:
        (version,) = struct.unpack_from("<B", sample_cert, 4)
        assert version == EIGENGRAM_VERSION

    def test_minimum_size(self, sample_cert: bytes, basis) -> None:
        rank = basis["basis"].shape[0]
        # Presumably: 99 fixed header bytes + two fp16 R-vectors + 128-dim
        # fp16 joint center — TODO confirm against the format spec.
        floor = 99 + rank * 2 + rank * 2 + 128 * 2
        assert len(sample_cert) >= floor

    def test_file_size_reasonable(self, sample_cert: bytes) -> None:
        assert len(sample_cert) < 2048
66
+
67
+
68
class TestRoundTrip:
    """Encode → decode round trip: every field must survive intact."""

    @staticmethod
    def _decode(cert: bytes):
        # A fresh decoder per call keeps the tests independent of each other.
        return EigramDecoder().decode(cert)

    def test_model_id(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["model_id"] == "Llama-3.1-8B"

    def test_basis_rank(self, sample_cert: bytes, basis) -> None:
        assert self._decode(sample_cert)["basis_rank"] == basis["basis"].shape[0]

    def test_vec_perdoc_shape(self, sample_cert: bytes, basis) -> None:
        rank = basis["basis"].shape[0]
        assert self._decode(sample_cert)["vec_perdoc"].shape == (rank,)

    def test_vec_fcdb_shape(self, sample_cert: bytes, basis) -> None:
        rank = basis["basis"].shape[0]
        assert self._decode(sample_cert)["vec_fcdb"].shape == (rank,)

    def test_joint_center_shape(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["joint_center"].shape == (128,)

    def test_scs(self, sample_cert: bytes) -> None:
        assert abs(self._decode(sample_cert)["scs"] - 0.42) < 0.01

    def test_margin_proof(self, sample_cert: bytes) -> None:
        assert abs(self._decode(sample_cert)["margin_proof"] - 0.013) < 0.001

    def test_task_description(self, sample_cert: bytes) -> None:
        assert "transformer" in self._decode(sample_cert)["task_description"]

    def test_cache_id(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["cache_id"] == "test-doc-001"

    def test_layer_range(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["layer_range"] == (8, 24)

    def test_n_corpus(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["n_corpus"] == 200

    def test_context_len(self, sample_cert: bytes) -> None:
        assert self._decode(sample_cert)["context_len"] == 512

    def test_float16_cosine_preserved(self, basis) -> None:
        """fp16 storage must not noticeably rotate a unit vector."""
        rank = basis["basis"].shape[0]
        unit = torch.randn(rank)
        unit = unit / unit.norm()
        cert = EigramEncoder().encode(
            vec_perdoc=unit, vec_fcdb=unit,
            joint_center=basis["joint_center"],
            corpus_hash="a" * 32, model_id="test",
            basis_rank=rank, n_corpus=200,
            layer_range=(8, 24), context_len=0,
            l2_norm=1.0, scs=0.5, margin_proof=0.0,
            task_description="cosine test", cache_id="cos",
        )
        decoded = self._decode(cert)
        cos = torch.nn.functional.cosine_similarity(
            unit.unsqueeze(0), decoded["vec_perdoc"].unsqueeze(0)
        ).item()
        assert cos > 0.999, f"Cosine after round-trip: {cos:.5f}"
136
+
137
+
138
class TestErrorHandling:
    """Decoder rejection of malformed certificates."""

    def test_bad_magic_raises(self) -> None:
        corrupt = b"XXXX" + b"\x00" * 200
        with pytest.raises(ValueError, match="magic"):
            EigramDecoder().decode(corrupt)

    def test_wrong_version_raises(self, sample_cert: bytes) -> None:
        mutated = bytearray(sample_cert)
        mutated[4] = 99  # the version byte sits at offset 4
        with pytest.raises(ValueError, match="version"):
            EigramDecoder().decode(bytes(mutated))

    def test_truncated_raises(self, sample_cert: bytes) -> None:
        with pytest.raises(Exception):
            EigramDecoder().decode(sample_cert[:20])
tests/test_embedder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for kvcos.engram.embedder — unified fingerprint embedding."""
2
+
3
+ import pytest
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from kvcos.engram.embedder import (
8
+ HashEmbedder,
9
+ get_embedder,
10
+ get_fingerprint,
11
+ reset_embedder,
12
+ )
13
+
14
+
15
class TestHashEmbedder:
    """Deterministic hash-based fallback embedder."""

    def test_deterministic(self):
        embedder = HashEmbedder(dim=128)
        assert torch.allclose(embedder.embed("hello"), embedder.embed("hello"))

    def test_different_text(self):
        embedder = HashEmbedder(dim=128)
        assert not torch.allclose(embedder.embed("hello"), embedder.embed("world"))

    def test_normalized(self):
        fingerprint = HashEmbedder(dim=128).embed("test")
        assert abs(torch.norm(fingerprint).item() - 1.0) < 0.01

    def test_dimension(self):
        embedder = HashEmbedder(dim=256)
        assert embedder.embed("test").shape == (256,)
        assert embedder.dim == 256

    def test_source_tag(self):
        assert HashEmbedder().source == "hash-fallback"
43
+
44
+
45
class TestGetFingerprint:
    """Top-level get_fingerprint() convenience wrapper."""

    def test_returns_tensor_and_source(self):
        fingerprint, backend = get_fingerprint("test text")
        assert isinstance(fingerprint, torch.Tensor)
        assert isinstance(backend, str)
        assert backend in ("llama_cpp", "sbert", "hash-fallback")

    def test_deterministic(self):
        first, _ = get_fingerprint("same text")
        second, _ = get_fingerprint("same text")
        assert torch.allclose(first, second)
56
+
57
+
58
class TestSBertEmbedder:
    """Exercise the sentence-transformers backend when it is installed."""

    @staticmethod
    def _sbert_or_skip():
        # Import lazily so environments without sentence-transformers skip
        # instead of erroring at collection time.
        try:
            from kvcos.engram.embedder import SBertEmbedder
            return SBertEmbedder()
        except ImportError:
            pytest.skip("sentence-transformers not installed")

    def test_sbert_available(self):
        """Verify sentence-transformers is usable."""
        embedder = self._sbert_or_skip()
        assert embedder.source == "sbert"
        assert embedder.dim == 384

    def test_semantic_discrimination(self):
        """Related texts should be more similar than unrelated ones."""
        embedder = self._sbert_or_skip()

        fp_a = embedder.embed("machine learning neural network training")
        fp_b = embedder.embed("deep learning model optimization")
        fp_c = embedder.embed("chocolate cake baking recipe")

        sim_ab = F.cosine_similarity(fp_a.unsqueeze(0), fp_b.unsqueeze(0)).item()
        sim_ac = F.cosine_similarity(fp_a.unsqueeze(0), fp_c.unsqueeze(0)).item()

        assert sim_ab > sim_ac, (
            f"Related topics ({sim_ab:.4f}) should be more similar "
            f"than unrelated ({sim_ac:.4f})"
        )
90
+
91
+
92
class TestGetEmbedder:
    """Module-level embedder singleton management."""

    def test_singleton(self):
        reset_embedder()
        first = get_embedder()
        second = get_embedder()
        assert first is second

    def test_reset(self):
        reset_embedder()
        _ = get_embedder()
        reset_embedder()
        rebuilt = get_embedder()
        # After a reset a usable embedder must come back; object identity with
        # the previous instance is not guaranteed either way.
        assert rebuilt is not None
tests/test_hnsw_index.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for kvcos.engram.hnsw_index — HNSW nearest-neighbor index."""
2
+
3
+ import torch
4
+ import pytest
5
+
6
+ from kvcos.engram.hnsw_index import EngramIndex, HNSWResult
7
+
8
+
9
@pytest.fixture
def small_index():
    """An 8-dim EngramIndex holding five deterministic orthogonal documents."""
    index = EngramIndex(dim=8)
    doc_ids = [f"doc_{i}" for i in range(5)]
    basis_rows = torch.eye(5, 8)  # deterministic, mutually orthogonal vectors
    index.add_batch(doc_ids, basis_rows)
    return index
18
+
19
+
20
class TestEngramIndexBuild:
    """Construction and introspection of a freshly built index."""

    def test_add_batch_len(self, small_index):
        assert len(small_index) == 5

    def test_add_batch_ids_stored(self, small_index):
        expected_ids = [f"doc_{i}" for i in range(5)]
        assert small_index._ids == expected_ids

    def test_repr(self, small_index):
        text = repr(small_index)
        assert "n=5" in text
        assert "dim=8" in text
31
+
32
+
33
class TestEngramIndexSearch:
    """Nearest-neighbour queries against the small fixture index."""

    def test_search_returns_results(self, small_index):
        probe = torch.eye(5, 8)[0]  # identical to doc_0's stored vector
        hits = small_index.search(probe, top_k=3)
        assert len(hits) == 3
        assert hits[0].doc_id == "doc_0"

    def test_search_scores_descending(self, small_index):
        hits = small_index.search(torch.eye(5, 8)[2], top_k=5)
        scores = [hit.score for hit in hits]
        assert scores == sorted(scores, reverse=True)

    def test_search_margin(self, small_index):
        hits = small_index.search(torch.eye(5, 8)[0], top_k=3)
        assert hits[0].margin >= 0

    def test_search_raises_before_build(self):
        empty_index = EngramIndex(dim=8)
        with pytest.raises(RuntimeError, match="not built"):
            empty_index.search(torch.randn(8), top_k=1)
55
+
56
+
57
class TestEngramIndexGetVector:
    """Reconstruction of stored vectors by document id."""

    def test_get_vector_returns_tensor(self, small_index):
        stored = small_index.get_vector("doc_0")
        assert stored is not None
        assert isinstance(stored, torch.Tensor)
        assert stored.shape == (8,)

    def test_get_vector_none_for_missing(self, small_index):
        assert small_index.get_vector("nonexistent") is None

    def test_get_vector_reconstructs_normalized(self, small_index):
        """Vectors are L2-normalized on add, so reconstruction is unit-length."""
        stored = small_index.get_vector("doc_0")
        assert abs(torch.norm(stored).item() - 1.0) < 0.01

    def test_get_vector_matches_original_direction(self, small_index):
        """Reconstruction should point the same way as the inserted vector."""
        inserted = torch.nn.functional.normalize(torch.eye(5, 8)[3:4], dim=-1)[0]
        stored = small_index.get_vector("doc_3")
        assert torch.dot(inserted, stored).item() > 0.99
80
+
81
+
82
class TestEngramIndexPersistence:
    """Save/load round trip on disk."""

    @staticmethod
    def _round_trip(index, tmp_path):
        # Persist under a temp prefix and reload into a fresh instance.
        prefix = str(tmp_path / "test_hnsw")
        index.save(prefix)
        return EngramIndex.load(prefix)

    def test_save_and_load(self, small_index, tmp_path):
        reloaded = self._round_trip(small_index, tmp_path)
        assert len(reloaded) == 5
        assert reloaded._ids == small_index._ids

    def test_loaded_search_matches_original(self, small_index, tmp_path):
        reloaded = self._round_trip(small_index, tmp_path)

        probe = torch.eye(5, 8)[1]
        before = [r.doc_id for r in small_index.search(probe, top_k=3)]
        after = [r.doc_id for r in reloaded.search(probe, top_k=3)]
        assert before == after

    def test_loaded_get_vector(self, small_index, tmp_path):
        reloaded = self._round_trip(small_index, tmp_path)

        stored = reloaded.get_vector("doc_2")
        assert stored is not None
        assert stored.shape == (8,)
tests/test_integration_synthetic.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Synthetic Integration Test
3
+ Full pipeline E2E with synthetic tensors — no real model needed.
4
+
5
+ Pipeline: create KV → extract state → serialize .eng → load → index → query → retrieve
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pathlib import Path
11
+
12
+ import pytest
13
+ import torch
14
+ from safetensors.torch import load_file
15
+
16
+ from kvcos.core.cache_spec import LLAMA_3_1_8B
17
+ from kvcos.core.serializer import EngramSerializer
18
+ from kvcos.core.types import CompressionMethod, StateExtractionMode
19
+ from kvcos.core.manifold_index import ManifoldIndex
20
+ from kvcos.core.retriever import EGRRetriever
21
+ from kvcos.core.state_extractor import MARStateExtractor
22
+ from kvcos.storage.local import LocalStorageBackend
23
+ from tests.conftest import make_synthetic_kv
24
+
25
+
26
class TestFullPipeline:
    """End-to-end over synthetic tensors: store → index → query → retrieve."""

    def test_serialize_round_trip(self, tmp_data_dir: Path) -> None:
        """Steps 1-4: create KV → serialize .eng → load → verify shapes."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
        assert keys.shape == (32, 8, 256, 128)

        serializer = EngramSerializer()
        eng_path = tmp_data_dir / "roundtrip.eng"
        serializer.serialize(
            keys=keys, values=values,
            agent_id="integration-test", task_description="round-trip test",
            model_id=LLAMA_3_1_8B["model_id"], output_path=eng_path,
            compression=CompressionMethod.FP16,
        )
        assert eng_path.exists()

        # The file on disk must be readable as plain safetensors.
        raw_tensors = load_file(str(eng_path))
        assert "layer_0_keys" in raw_tensors

        loaded_keys, loaded_values, meta = serializer.deserialize(eng_path)
        assert loaded_keys.shape == keys.shape
        assert loaded_values.shape == values.shape

    @pytest.mark.parametrize("mode", list(StateExtractionMode))
    def test_extraction_all_modes(self, mode: StateExtractionMode) -> None:
        """Step 2: a state vector can be extracted in every supported mode."""
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
        extraction = MARStateExtractor(mode=mode, rank=128).extract(
            keys, LLAMA_3_1_8B
        )

        assert extraction.state_vec.dim() == 1
        assert extraction.state_vec.shape[0] > 0
        assert extraction.l2_norm > 0
        assert extraction.mode == mode

    def test_index_and_query(self, tmp_data_dir: Path) -> None:
        """Steps 5-6: index one state vector, find it from a different query."""
        from kvcos.core.manifold_index import IndexEntry

        keys_indexed, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=42)
        keys_query, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=99)

        extractor = MARStateExtractor(
            mode=StateExtractionMode.MEAN_POOL,
        )
        index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))

        # Extract and index the first tensor's state vector.
        indexed = extractor.extract(keys_indexed, LLAMA_3_1_8B)
        index.add(
            indexed.state_vec,
            IndexEntry(
                cache_id="test-cache-a",
                task_description="indexed engram",
                model_id=LLAMA_3_1_8B["model_id"],
                created_at="2026-01-01T00:00:00Z",
                context_len=256,
                l2_norm=indexed.l2_norm,
            ),
        )

        # Query with the second tensor's state vector.
        probe = extractor.extract(keys_query, LLAMA_3_1_8B)
        hits = index.search(probe.state_vec, top_k=1)

        assert len(hits) >= 1
        assert hits[0]["cache_id"] == "test-cache-a"

    def test_full_egr_pipeline(self, tmp_data_dir: Path) -> None:
        """Step 7: full EGR retrieval — store → index → query → retrieve."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=42)
        query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=99)

        extractor = MARStateExtractor(
            mode=StateExtractionMode.MEAN_POOL,
        )
        index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
        storage = LocalStorageBackend(data_dir=tmp_data_dir)
        retriever = EGRRetriever(extractor, index, storage)

        # Store one engram.
        cache_id = retriever.index_engram(
            keys=keys, values=values, spec=LLAMA_3_1_8B,
            agent_id="integration-test",
            task_description="full pipeline test",
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=tmp_data_dir,
        )
        assert isinstance(cache_id, str)
        assert index.n_entries == 1

        # Retrieve it via a different query tensor.
        response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=1)
        assert len(response.results) >= 1

        best = response.results[0]
        assert best.cache_id == cache_id
        assert best.keys.shape == keys.shape
        assert best.values.shape == values.shape
        assert best.similarity != 0.0

    def test_multi_engram_ranking(self, tmp_data_dir: Path) -> None:
        """Store 3 engrams, query, verify results are ranked by similarity."""
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
        storage = LocalStorageBackend(data_dir=tmp_data_dir)
        retriever = EGRRetriever(extractor, index, storage)

        for seed in (10, 20, 30):
            keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=seed)
            retriever.index_engram(
                keys=keys, values=values, spec=LLAMA_3_1_8B,
                agent_id="test", task_description=f"seed-{seed}",
                model_id=LLAMA_3_1_8B["model_id"],
                output_dir=tmp_data_dir,
            )

        assert index.n_entries == 3

        query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=10)
        response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=3)

        assert len(response.results) == 3
        # Results must come back sorted by descending similarity.
        similarities = [r.similarity for r in response.results]
        assert similarities == sorted(similarities, reverse=True)
        # The closest match should be the engram built from the query's seed.
        assert response.results[0].task_description == "seed-10"
tests/test_iswa_blob_parser.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ISWA Blob Parser Tests
3
+ Tests for multi-section KV cache parsing (Gemma 4 ISWA format).
4
+
5
+ Uses synthetic ISWA blobs from conftest.make_synthetic_iswa_blob().
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import pytest
11
+ import torch
12
+
13
+ from kvcos.core.blob_parser import (
14
+ BlobParseError,
15
+ ParsedKVCache,
16
+ ParsedMultiSectionCache,
17
+ parse_multi_section_blob,
18
+ parse_state_blob,
19
+ )
20
+ from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
21
+ from tests.conftest import (
22
+ GEMMA4_GLOBAL_SECTION,
23
+ GEMMA4_SECTIONS,
24
+ GEMMA4_SWA_SECTION,
25
+ make_synthetic_iswa_blob,
26
+ )
27
+
28
+
29
class TestParseMultiSectionBlob:
    """Happy-path parsing of multi-section (ISWA) KV cache blobs."""

    def test_parse_gemma4_shape(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        assert len(parsed.sections) == 2

        # Global-attention section: [5, 2, 4, 512]
        global_sec = parsed.sections[0]
        assert global_sec.keys.shape == (5, 2, 4, 512)
        assert global_sec.values.shape == (5, 2, 4, 512)

        # Sliding-window section: [25, 8, 4, 256]
        swa_sec = parsed.sections[1]
        assert swa_sec.keys.shape == (25, 8, 4, 256)
        assert swa_sec.values.shape == (25, 8, 4, 256)

    def test_parse_metadata(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        assert parsed.arch == "gemma4"
        assert parsed.n_sections == 2
        assert parsed.total_layers == 30

        assert parsed.sections[0].n_layers == 5
        assert parsed.sections[0].arch == "gemma4"
        assert parsed.sections[1].n_layers == 25

    def test_parse_cells(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        for section in parsed.sections:
            assert section.n_cells == 4
            assert len(section.cells) == 4
            assert section.cells[0].pos == 0
            assert section.cells[3].pos == 3

    def test_dtype_float16(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        for section in parsed.sections:
            assert section.keys.dtype == torch.float16
            assert section.values.dtype == torch.float16

    def test_different_cell_counts(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=8)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        assert parsed.sections[0].n_cells == 8
        assert parsed.sections[1].n_cells == 8

    def test_non_transposed_v(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=2, v_trans=False)
        parsed = parse_multi_section_blob(blob, GEMMA4_SECTIONS)

        for section in parsed.sections:
            assert section.v_trans is False

    def test_single_section_works(self) -> None:
        """A one-section ISWA blob parses like a standard single-stream blob."""
        solo = (GEMMA4_GLOBAL_SECTION,)
        blob = make_synthetic_iswa_blob(solo, n_cells=4)
        parsed = parse_multi_section_blob(blob, solo)

        assert len(parsed.sections) == 1
        assert parsed.sections[0].keys.shape == (5, 2, 4, 512)
100
+
101
+
102
class TestParseMultiSectionErrors:
    """Malformed-input handling for ISWA blob parsing."""

    def test_section_mismatch_raises(self) -> None:
        """A two-section blob checked against three section specs must fail."""
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        extra_sections = GEMMA4_SECTIONS + (GEMMA4_GLOBAL_SECTION,)
        with pytest.raises(BlobParseError, match="Expected 3.*got 2"):
            parse_multi_section_blob(blob, extra_sections)

    def test_truncated_blob_raises(self) -> None:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        with pytest.raises(BlobParseError):
            parse_multi_section_blob(blob[:100], GEMMA4_SECTIONS)

    def test_wrong_dimensions_raises(self) -> None:
        """A spec with the wrong KV-head count must not parse."""
        bad_specs = (
            CacheSection(AttentionType.FULL, 5, 4, 512),  # wrong: 4 heads not 2
            GEMMA4_SWA_SECTION,
        )
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        with pytest.raises(BlobParseError):
            parse_multi_section_blob(blob, bad_specs)
126
+
127
+
128
class TestStandardBlobBackwardCompat:
    """parse_state_blob must keep working for single-stream blobs."""

    def test_single_stream_still_works(self) -> None:
        # Reuse the sibling module's synthetic blob builder.
        from tests.test_blob_parser import _make_blob

        legacy_blob = _make_blob(16, 32, 8, 128)
        parsed = parse_state_blob(legacy_blob, n_kv_heads=8, head_dim=128)
        assert parsed.keys.shape == (32, 8, 16, 128)
tests/test_iswa_bridge.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ISWA Bridge Tests
3
+ Tests for multi-architecture metadata detection and ISWA cache extraction.
4
+ Does NOT require a real GGUF model — tests the metadata helpers and spec logic.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import pytest
10
+
11
+ from integrations.llama_cpp_bridge import _meta_get
12
+ from kvcos.core.cache_spec import (
13
+ GEMMA_4_26B_A4B,
14
+ LLAMA_3_1_8B,
15
+ get_model_spec,
16
+ is_iswa_spec,
17
+ )
18
+
19
+
20
class TestMetaGet:
    """Metadata key fallback chain across architecture prefixes."""

    def test_llama_prefix(self) -> None:
        assert _meta_get({"llama.block_count": "32"}, "block_count") == "32"

    def test_gemma4_prefix(self) -> None:
        assert _meta_get({"gemma4.block_count": "30"}, "block_count") == "30"

    def test_gemma_prefix(self) -> None:
        meta = {"gemma.attention.head_count": "8"}
        assert _meta_get(meta, "attention.head_count") == "8"

    def test_general_fallback(self) -> None:
        assert _meta_get({"general.block_count": "28"}, "block_count") == "28"

    def test_default_when_missing(self) -> None:
        assert _meta_get({}, "block_count", "32") == "32"

    def test_llama_takes_priority(self) -> None:
        every_prefix = {
            "llama.block_count": "32",
            "gemma4.block_count": "30",
            "general.block_count": "28",
        }
        assert _meta_get(every_prefix, "block_count") == "32"

    def test_gemma4_before_general(self) -> None:
        meta = {
            "gemma4.embedding_length": "3072",
            "general.embedding_length": "4096",
        }
        assert _meta_get(meta, "embedding_length") == "3072"
+
58
+
59
+ class TestISWASpecDetection:
60
+ """Registry and ISWA detection."""
61
+
62
+ def test_gemma4_in_registry(self) -> None:
63
+ spec = get_model_spec("google/gemma-4-26b-a4b-it")
64
+ assert spec is not None
65
+ assert spec["model_family"] == "gemma"
66
+
67
+ def test_gemma4_is_iswa(self) -> None:
68
+ assert is_iswa_spec(GEMMA_4_26B_A4B) is True
69
+
70
+ def test_llama_not_iswa(self) -> None:
71
+ assert is_iswa_spec(LLAMA_3_1_8B) is False
72
+
73
+ def test_gemma4_sections_correct(self) -> None:
74
+ sections = GEMMA_4_26B_A4B["cache_sections"]
75
+ assert len(sections) == 2
76
+
77
+ # Global section
78
+ assert sections[0].n_layers == 5
79
+ assert sections[0].n_kv_heads == 2
80
+ assert sections[0].head_dim == 512
81
+ assert sections[0].n_embd_kv == 1024
82
+
83
+ # SWA section
84
+ assert sections[1].n_layers == 25
85
+ assert sections[1].n_kv_heads == 8
86
+ assert sections[1].head_dim == 256
87
+ assert sections[1].window_size == 1024
88
+
89
+ def test_gemma4_total_layers(self) -> None:
90
+ sections = GEMMA_4_26B_A4B["cache_sections"]
91
+ total = sum(s.n_layers for s in sections)
92
+ assert total == GEMMA_4_26B_A4B["n_layers"]
tests/test_iswa_fingerprint.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ISWA Fingerprint Tests
3
+ Tests for per-section Fourier fingerprint computation and concatenation.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+
10
+ from kvcos.core.blob_parser import ParsedKVCache, ParsedMultiSectionCache, parse_multi_section_blob
11
+ from kvcos.core.fingerprint import (
12
+ compute_fourier_fingerprint_v2,
13
+ compute_iswa_fingerprint,
14
+ )
15
+ from kvcos.core.types import AttentionType, CacheSection
16
+ from tests.conftest import GEMMA4_SECTIONS, make_synthetic_iswa_blob
17
+
18
+
19
class TestISWAFingerprint:
    """Per-section Fourier fingerprint computation for ISWA models."""

    def _make_parsed(self, n_cells: int = 4) -> ParsedMultiSectionCache:
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=n_cells)
        return parse_multi_section_blob(blob, GEMMA4_SECTIONS)

    def test_fingerprint_shape(self) -> None:
        fp = compute_iswa_fingerprint(self._make_parsed(), freqs=[0, 1])
        # Global: 2 * 512 * 2 = 2048; SWA: 8 * 256 * 2 = 4096; total 6144.
        assert fp.shape == (6144,)

    def test_fingerprint_dtype(self) -> None:
        fp = compute_iswa_fingerprint(self._make_parsed())
        assert fp.dtype == torch.float32

    def test_fingerprint_normalized(self) -> None:
        """Each section sub-FP concatenates per-freq L2-normalized vectors."""
        import math

        fp = compute_iswa_fingerprint(self._make_parsed(), freqs=[0, 1])
        global_part = fp[:2048]   # 1024 dims per freq, 2 freqs
        swa_part = fp[2048:]      # 2048 dims per freq, 2 freqs

        # Two concatenated unit vectors have overall norm sqrt(2).
        target = math.sqrt(2)
        assert abs(global_part.norm().item() - target) < 0.05
        assert abs(swa_part.norm().item() - target) < 0.05

    def test_deterministic(self) -> None:
        parsed = self._make_parsed()
        first = compute_iswa_fingerprint(parsed)
        second = compute_iswa_fingerprint(parsed)
        assert torch.allclose(first, second)

    def test_different_inputs_differ(self) -> None:
        parsed_a = self._make_parsed(n_cells=4)
        other_blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4, seed=999)
        parsed_b = parse_multi_section_blob(other_blob, GEMMA4_SECTIONS)

        fp_a = compute_iswa_fingerprint(parsed_a)
        fp_b = compute_iswa_fingerprint(parsed_b)
        cos = torch.nn.functional.cosine_similarity(
            fp_a.unsqueeze(0), fp_b.unsqueeze(0)
        )
        assert cos.item() < 0.99  # different inputs → different FPs

    def test_single_section_matches_standard(self) -> None:
        """With one section, the ISWA FP reduces to the standard FP."""
        section = CacheSection(AttentionType.FULL, 5, 2, 512)
        blob = make_synthetic_iswa_blob((section,), n_cells=4)
        parsed = parse_multi_section_blob(blob, (section,))

        iswa_fp = compute_iswa_fingerprint(parsed, freqs=[0, 1])

        # Reference: standard FP over the cell-averaged keys of the section.
        pooled_keys = parsed.sections[0].keys.float().mean(dim=2)
        reference_fp = compute_fourier_fingerprint_v2(pooled_keys, freqs=[0, 1])

        assert torch.allclose(iswa_fp, reference_fp, atol=1e-5)

    def test_custom_freqs(self) -> None:
        parsed = self._make_parsed()
        single_freq = compute_iswa_fingerprint(parsed, freqs=[0])
        double_freq = compute_iswa_fingerprint(parsed, freqs=[0, 1])

        assert single_freq.shape == (3072,)  # Global(1024) + SWA(2048)
        assert double_freq.shape == (6144,)  # doubled with a second frequency
tests/test_iswa_types.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ISWA Type System Tests
3
+ Tests for CacheSection, AttentionType, and ISWA-aware ModelCacheSpec.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import pytest
9
+
10
+ from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
11
+
12
+
13
class TestAttentionType:
    """AttentionType string-enum members and their raw values."""

    def test_full_value(self) -> None:
        # Str-enum members compare equal to their raw string value.
        assert AttentionType.FULL == "full"

    def test_sliding_value(self) -> None:
        assert AttentionType.SLIDING == "sliding"

    def test_is_str(self) -> None:
        # Members subclass str, so they serialize transparently (e.g. to JSON).
        assert isinstance(AttentionType.FULL, str)
+
25
+
26
class TestCacheSection:
    """CacheSection frozen dataclass: construction, derived fields, identity."""

    @staticmethod
    def _full_section() -> CacheSection:
        # Canonical global-attention section reused across several tests.
        return CacheSection(
            attention_type=AttentionType.FULL,
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        )

    def test_global_section(self) -> None:
        sec = self._full_section()
        assert sec.n_layers == 5
        assert sec.n_kv_heads == 2
        assert sec.head_dim == 512
        assert sec.window_size is None  # full attention has no sliding window

    def test_sliding_section(self) -> None:
        sec = CacheSection(
            attention_type=AttentionType.SLIDING,
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,
        )
        assert sec.attention_type == AttentionType.SLIDING
        assert sec.window_size == 1024

    def test_n_embd_kv(self) -> None:
        # n_embd_kv is the derived product n_kv_heads * head_dim.
        assert self._full_section().n_embd_kv == 1024  # 2 * 512

    def test_frozen(self) -> None:
        sec = self._full_section()
        with pytest.raises(AttributeError):
            sec.n_layers = 10  # type: ignore[misc]

    def test_equality(self) -> None:
        assert CacheSection(AttentionType.FULL, 5, 2, 512) == CacheSection(
            AttentionType.FULL, 5, 2, 512
        )

    def test_inequality(self) -> None:
        assert CacheSection(AttentionType.FULL, 5, 2, 512) != CacheSection(
            AttentionType.SLIDING, 25, 8, 256, 1024
        )
+
81
+
82
class TestModelCacheSpecISWA:
    """ModelCacheSpec behavior with and without the optional cache_sections key."""

    def test_standard_spec_no_sections(self) -> None:
        spec = ModelCacheSpec(
            model_id="meta-llama/Llama-3.1-8B-Instruct",
            model_family="llama",
            n_layers=32,
            n_heads=32,
            n_kv_heads=8,
            head_dim=128,
            rope_enabled=True,
            extraction_layers=tuple(range(8, 32)),
        )
        # Non-ISWA specs simply omit the key entirely.
        assert "cache_sections" not in spec
        assert spec["n_kv_heads"] == 8

    def test_iswa_spec_with_sections(self) -> None:
        sections = (
            CacheSection(AttentionType.FULL, 5, 2, 512),
            CacheSection(AttentionType.SLIDING, 25, 8, 256, 1024),
        )
        spec = ModelCacheSpec(
            model_id="google/gemma-4-26b-a4b-it",
            model_family="gemma",
            n_layers=30,
            n_heads=32,
            n_kv_heads=8,
            head_dim=256,
            rope_enabled=True,
            extraction_layers=tuple(range(8, 30)),
            cache_sections=sections,
        )
        assert "cache_sections" in spec
        assert len(spec["cache_sections"]) == 2
        # Global: 2 heads * 512 dim; sliding: 8 heads * 256 dim.
        assert spec["cache_sections"][0].n_embd_kv == 1024
        assert spec["cache_sections"][1].n_embd_kv == 2048

    def test_iswa_total_layers_match(self) -> None:
        sections = (
            CacheSection(AttentionType.FULL, 5, 2, 512),
            CacheSection(AttentionType.SLIDING, 25, 8, 256, 1024),
        )
        # Section layer counts must add up to the model's total depth.
        assert sum(s.n_layers for s in sections) == 30

    def test_backward_compat_existing_specs(self) -> None:
        """Existing specs without cache_sections still work."""
        from kvcos.core.cache_spec import LLAMA_3_1_8B
        assert LLAMA_3_1_8B["n_kv_heads"] == 8
        assert "cache_sections" not in LLAMA_3_1_8B
+ assert "cache_sections" not in LLAMA_3_1_8B
tests/test_knowledge_index.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for kvcos.engram.knowledge_index — HNSW knowledge search."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ from kvcos.engram.embedder import get_fingerprint
10
+ from kvcos.engram.format import EigramEncoder
11
+ from kvcos.engram.knowledge_index import KnowledgeIndex
12
+
13
+
14
@pytest.fixture
def knowledge_dir(tmp_path):
    """Create a temporary knowledge directory with test .eng files."""
    encoder = EigramEncoder()
    project_dir = tmp_path / "test_project"
    project_dir.mkdir()

    corpus = [
        ("doc_ml", "Machine learning model training and optimization"),
        ("doc_db", "PostgreSQL database schema migration tools"),
        ("doc_api", "REST API endpoint authentication and authorization"),
        ("doc_test", "Unit testing with pytest fixtures and mocking"),
        ("doc_deploy", "Docker container deployment to Kubernetes cluster"),
    ]

    for doc_id, text in corpus:
        fingerprint, fp_origin = get_fingerprint(text)
        fp_dim = fingerprint.shape[0]

        # Encode a minimal .eng blob; only the fingerprint fields matter here,
        # the remaining header values are placeholder zeros.
        blob = encoder.encode(
            vec_perdoc=torch.zeros(116),
            vec_fcdb=torch.zeros(116),
            joint_center=torch.zeros(128),
            corpus_hash="test" * 8,
            model_id=fp_origin[:16],
            basis_rank=116,
            n_corpus=0,
            layer_range=(0, 0),
            context_len=len(text),
            l2_norm=float(torch.norm(fingerprint).item()),
            scs=0.0,
            margin_proof=0.0,
            task_description=text[:256],
            cache_id=doc_id,
            vec_fourier=fingerprint if fp_dim == 2048 else None,
            vec_fourier_v2=fingerprint,
            confusion_flag=False,
        )

        eng_path = project_dir / f"{doc_id}.eng"
        eng_path.write_bytes(blob)

        # Sidecar metadata JSON, as produced by the real indexer.
        sidecar = {
            "cache_id": doc_id,
            "task_description": text,
            "source_path": f"/test/{doc_id}.md",
            "project": "test_project",
            "fp_source": fp_origin,
            "chunk_index": 0,
            "chunk_total": 1,
            "headers": [],
        }
        Path(str(eng_path) + ".meta.json").write_text(json.dumps(sidecar))

    return tmp_path
70
+
71
+
72
class TestKnowledgeIndexBuild:
    """Building a KnowledgeIndex from a directory of .eng files."""

    def test_build_from_directory(self, knowledge_dir):
        index = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        # One entry per fixture document.
        assert len(index) == 5

    def test_build_empty_directory(self, tmp_path):
        # A directory without any .eng files is a usage error.
        with pytest.raises(ValueError, match="No .eng files"):
            KnowledgeIndex.build_from_knowledge_dir(tmp_path, verbose=False)
82
+
83
+
84
class TestKnowledgeIndexSearch:
    """Text and tensor queries against a built KnowledgeIndex."""

    def test_search_returns_results(self, knowledge_dir):
        index = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        hits = index.search("database query optimization", k=3)
        assert len(hits) == 3
        assert all(hit.score > 0 for hit in hits)

    def test_search_result_fields(self, knowledge_dir):
        index = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        top = index.search("testing", k=1)[0]
        assert top.doc_id
        assert isinstance(top.score, float)
        assert top.rank == 0
        assert top.project == "test_project"

    def test_search_with_tensor(self, knowledge_dir):
        # A precomputed fingerprint tensor is accepted in place of query text.
        index = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        query_fp, _ = get_fingerprint("unit tests")
        assert len(index.search(query_fp, k=2)) == 2

    def test_search_margin(self, knowledge_dir):
        index = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        hits = index.search("testing", k=3)
        # Top result should have a non-negative margin over the runner-up.
        assert hits[0].margin >= 0
119
+
120
+
121
class TestKnowledgeIndexPersistence:
    """Save/load round-trip of the knowledge index."""

    def test_save_and_load(self, knowledge_dir, tmp_path):
        original = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        index_dir = tmp_path / "index"
        original.save(index_dir)

        restored = KnowledgeIndex.load(index_dir)
        assert len(restored) == len(original)

        # The restored index must remain searchable.
        hits = restored.search("database", k=2)
        assert len(hits) == 2

    def test_load_nonexistent(self, tmp_path):
        with pytest.raises(FileNotFoundError):
            KnowledgeIndex.load(tmp_path / "nonexistent")
tests/test_manifest.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for kvcos.engram.manifest — knowledge index registry."""
2
+
3
+ import json
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+ from kvcos.engram.manifest import ChunkRecord, Manifest, SourceRecord, _content_hash
10
+
11
+
12
@pytest.fixture
def tmp_manifest(tmp_path):
    """Return an empty Manifest backed by a manifest.json in a temp dir."""
    manifest_path = tmp_path / "manifest.json"
    return Manifest.load(manifest_path)
16
+
17
+
18
class TestContentHash:
    """_content_hash must be a pure function of its input text."""

    def test_deterministic(self):
        assert _content_hash("hello") == _content_hash("hello")

    def test_different_content(self):
        assert _content_hash("hello") != _content_hash("world")
24
+
25
+
26
class TestManifestLoad:
    """Loading manifests from disk (missing and existing files)."""

    def test_load_nonexistent_creates_empty(self, tmp_path):
        manifest = Manifest.load(tmp_path / "does_not_exist.json")
        assert manifest.total_sources == 0
        assert manifest.total_chunks == 0

    def test_load_existing(self, tmp_path):
        # Register one source (which persists to disk), then reload it.
        manifest = Manifest.load(tmp_path / "manifest.json")
        manifest = manifest.register(
            source_path="/test/file.md",
            content_hash="abc123",
            project="test",
            file_size=100,
            chunks=[ChunkRecord(
                eng_path="/test/file.eng",
                chunk_index=0,
                chunk_total=1,
                char_start=0,
                char_end=100,
                indexed_at=1000.0,
            )],
        )

        reloaded = Manifest.load(tmp_path / "manifest.json")
        assert reloaded.total_sources == 1
        assert reloaded.total_chunks == 1
54
+
55
+
56
class TestManifestRegister:
    """Registering sources: creation, overwrite, immutability."""

    def test_register_new(self, tmp_manifest):
        record = ChunkRecord(
            eng_path="/out/test.eng",
            chunk_index=0,
            chunk_total=1,
            char_start=0,
            char_end=50,
            indexed_at=1000.0,
        )
        updated = tmp_manifest.register(
            source_path="/src/test.md",
            content_hash="hash1",
            project="myproject",
            file_size=50,
            chunks=[record],
        )
        assert updated.total_sources == 1
        assert updated.total_chunks == 1
        assert "myproject" in updated.projects

    def test_register_overwrites_existing(self, tmp_manifest):
        v1_chunks = [ChunkRecord(
            eng_path="/out/v1.eng", chunk_index=0, chunk_total=1,
            char_start=0, char_end=50, indexed_at=1000.0,
        )]
        manifest = tmp_manifest.register(
            "/src/test.md", "hash1", "proj", 50, v1_chunks,
        )
        assert manifest.total_chunks == 1

        # Re-registering the same path replaces its chunk list wholesale.
        v2_chunks = [
            ChunkRecord("/out/v2_1.eng", 0, 2, 0, 25, 2000.0),
            ChunkRecord("/out/v2_2.eng", 1, 2, 25, 50, 2000.0),
        ]
        manifest = manifest.register("/src/test.md", "hash2", "proj", 50, v2_chunks)
        assert manifest.total_sources == 1  # still 1 source
        assert manifest.total_chunks == 2  # now 2 chunks

    def test_register_returns_new_manifest(self, tmp_manifest):
        """Register returns a new Manifest (immutability)."""
        before = tmp_manifest
        after = before.register("/src/a.md", "h", "p", 10, [])
        assert before.total_sources == 0  # original unchanged
        assert after.total_sources == 1
101
+
102
+
103
class TestManifestNeedsReindex:
    """Change detection via content hash comparison."""

    def test_unknown_file_needs_index(self, tmp_manifest):
        # Never-seen paths always require indexing.
        assert tmp_manifest.needs_reindex("/new/file.md", "any_hash")

    def test_same_hash_no_reindex(self, tmp_manifest):
        registered = tmp_manifest.register("/src/a.md", "hash1", "p", 10, [])
        assert not registered.needs_reindex("/src/a.md", "hash1")

    def test_different_hash_needs_reindex(self, tmp_manifest):
        registered = tmp_manifest.register("/src/a.md", "hash1", "p", 10, [])
        assert registered.needs_reindex("/src/a.md", "hash2")
114
+
115
+
116
class TestManifestUnregister:
    """Removing sources from the manifest."""

    def test_unregister_existing(self, tmp_manifest):
        manifest = tmp_manifest.register("/src/a.md", "h", "p", 10, [])
        manifest = manifest.unregister("/src/a.md")
        assert manifest.total_sources == 0

    def test_unregister_nonexistent(self, tmp_manifest):
        # Unregistering an unknown path is a silent no-op.
        assert tmp_manifest.unregister("/not/here.md").total_sources == 0
125
+
126
+
127
class TestManifestQueries:
    """Read-only query helpers: project records, summary, dunders."""

    def test_get_project_records(self, tmp_manifest):
        manifest = tmp_manifest
        manifest = manifest.register("/a.md", "h1", "proj_a", 10, [])
        manifest = manifest.register("/b.md", "h2", "proj_b", 20, [])
        manifest = manifest.register("/c.md", "h3", "proj_a", 30, [])

        # Only the two proj_a sources should come back.
        assert len(manifest.get_project_records("proj_a")) == 2

    def test_summary(self, tmp_manifest):
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [
            ChunkRecord("/a.eng", 0, 1, 0, 10, 1000.0),
        ])
        summary = manifest.summary()
        assert summary["total_sources"] == 1
        assert summary["total_chunks"] == 1
        assert "p" in summary["projects"]

    def test_contains(self, tmp_manifest):
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [])
        assert "/a.md" in manifest
        assert "/b.md" not in manifest

    def test_len(self, tmp_manifest):
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [])
        assert len(manifest) == 1
154
+
155
+
156
class TestManifestPersistence:
    """On-disk representation written by register()."""

    def test_atomic_write(self, tmp_path):
        manifest = Manifest.load(tmp_path / "manifest.json")
        manifest = manifest.register("/a.md", "h", "p", 10, [])

        manifest_file = tmp_path / "manifest.json"
        assert manifest_file.exists()

        # The persisted file must be well-formed JSON with the expected schema.
        data = json.loads(manifest_file.read_text())
        assert data["version"] == 1
        assert len(data["sources"]) == 1
tests/test_manifold_index.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Manifold Index Tests
3
+ Tests for FAISS IndexFlatIP add/search/remove/persist (D2, D4).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import pytest
12
+ import torch
13
+
14
+ from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
15
+
16
+
17
def _entry(cid: str = "c1", model: str = "llama") -> IndexEntry:
    """Build a minimal IndexEntry with fixed metadata for index tests."""
    return IndexEntry(
        cache_id=cid,
        task_description="test",
        model_id=model,
        created_at="2026-01-01T00:00:00Z",
        context_len=256,
        l2_norm=1.0,
    )
23
+
24
+
25
class TestAddAndSearch:
    """Add vectors, search via MIPS."""

    def test_add_increments(self) -> None:
        index = ManifoldIndex(dim=8)
        index.add(torch.randn(8), _entry("a"))
        index.add(torch.randn(8), _entry("b"))
        assert index.n_entries == 2

    def test_search_returns_correct_order(self) -> None:
        index = ManifoldIndex(dim=4)
        near = torch.tensor([1.0, 0.0, 0.0, 0.0])
        orthogonal = torch.tensor([0.0, 1.0, 0.0, 0.0])
        index.add(near, _entry("close"))
        index.add(orthogonal, _entry("far"))

        # Query aligned with `near` → "close" must rank first.
        hits = index.search(torch.tensor([1.0, 0.0, 0.0, 0.0]), top_k=2)
        assert hits[0]["cache_id"] == "close"
        assert hits[0]["similarity"] > hits[1]["similarity"]

    def test_search_empty_returns_empty(self) -> None:
        assert ManifoldIndex(dim=4).search(torch.randn(4), top_k=5) == []

    def test_model_filter(self) -> None:
        index = ManifoldIndex(dim=4)
        index.add(torch.randn(4), _entry("a", model="llama"))
        index.add(torch.randn(4), _entry("b", model="phi"))
        hits = index.search(torch.randn(4), top_k=10, model_id="phi")
        assert all(hit["model_id"] == "phi" for hit in hits)
57
+
58
+
59
class TestRemoveAndRebuild:
    """Remove entries and rebuild index."""

    def test_remove_hides_from_search(self) -> None:
        index = ManifoldIndex(dim=4)
        vec = torch.tensor([1.0, 0.0, 0.0, 0.0])
        index.add(vec, _entry("target"))
        assert index.remove("target")
        # A removed entry must never surface again, even for its own vector.
        assert len(index.search(vec, top_k=1)) == 0

    def test_rebuild_compacts(self) -> None:
        index = ManifoldIndex(dim=4)
        for n in range(5):
            index.add(torch.randn(4), _entry(f"c{n}"))
        index.remove("c1")
        index.remove("c3")
        # Rebuild drops tombstones and reports the surviving count.
        assert index.rebuild() == 3
78
+
79
+
80
class TestPersistence:
    """Save/load round-trip (D2: serialize_index/deserialize_index)."""

    def test_save_load_round_trip(self, tmp_index_dir: Path) -> None:
        index = ManifoldIndex(dim=4)
        vec = torch.tensor([1.0, 0.0, 0.0, 0.0])
        index.add(vec, _entry("persisted"))
        index.save(tmp_index_dir / "test.faiss")

        # Constructing with index_path loads the persisted state.
        reloaded = ManifoldIndex(dim=4, index_path=tmp_index_dir / "test.faiss")
        assert reloaded.n_entries == 1
        hits = reloaded.search(vec, top_k=1)
        assert hits[0]["cache_id"] == "persisted"

    def test_dim_mismatch_raises(self) -> None:
        index = ManifoldIndex(dim=4)
        with pytest.raises(ValueError, match="dim"):
            index.add(torch.randn(8), _entry("wrong"))
tests/test_retriever.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Retriever Tests
3
+ Tests for EGRRetriever: store → index → query → retrieve pipeline.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+
10
+ import torch
11
+
12
+ from kvcos.core.cache_spec import LLAMA_3_1_8B
13
+ from kvcos.core.serializer import EngramSerializer
14
+ from kvcos.core.types import CompressionMethod, StateExtractionMode
15
+ from kvcos.core.manifold_index import ManifoldIndex
16
+ from kvcos.core.retriever import EGRRetriever, RetrievalResponse
17
+ from kvcos.core.state_extractor import MARStateExtractor
18
+ from kvcos.storage.local import LocalStorageBackend
19
+ from tests.conftest import make_synthetic_kv
20
+
21
+
22
def _build_retriever(
    data_dir: Path, mode: StateExtractionMode = StateExtractionMode.MEAN_POOL,
) -> EGRRetriever:
    """Wire extractor + FAISS index + local storage into an EGRRetriever."""
    extractor = MARStateExtractor(mode=mode, rank=128)
    # Index dimensionality must match what the extractor emits for this model.
    index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
    backend = LocalStorageBackend(data_dir=data_dir)
    return EGRRetriever(extractor, index, backend)
30
+
31
+
32
class TestIndexAndRetrieve:
    """Full store → search → load pipeline."""

    @staticmethod
    def _store(retriever, keys, values, data_dir: Path, description: str):
        # All index_engram calls in this class share everything but the description.
        return retriever.index_engram(
            keys=keys, values=values, spec=LLAMA_3_1_8B,
            agent_id="test", task_description=description,
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=data_dir,
        )

    def test_index_returns_cache_id(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)

        cache_id = self._store(retriever, keys, values, tmp_data_dir, "test engram")
        assert isinstance(cache_id, str)
        assert len(cache_id) > 0

    def test_retrieve_finds_stored(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        self._store(retriever, keys, values, tmp_data_dir, "findable engram")

        # A differently-seeded query should still surface the only stored engram.
        query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=99)
        response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=1)

        assert isinstance(response, RetrievalResponse)
        assert len(response.results) == 1
        assert response.results[0].keys.shape == keys.shape

    def test_retrieve_empty_index(self, tmp_data_dir: Path) -> None:
        retriever = _build_retriever(tmp_data_dir)
        query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=5)
        assert len(response.results) == 0

    def test_delete_removes(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)

        cache_id = self._store(retriever, keys, values, tmp_data_dir, "deletable")
        assert retriever.delete_engram(cache_id)

        # After deletion the index should be empty again.
        query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=5)
        assert len(response.results) == 0
tests/test_serializer.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Serializer Tests
3
+ Tests for .eng safetensors serialize/deserialize round-trip (D7).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+
10
+ import pytest
11
+ import torch
12
+ from safetensors.torch import load_file
13
+
14
+ from kvcos.core.serializer import EngramSerializer, SerializationError
15
+ from kvcos.core.types import CompressionMethod
16
+ from tests.conftest import make_synthetic_kv
17
+ from kvcos.core.cache_spec import LLAMA_3_1_8B
18
+
19
+
20
class TestSerializeRoundTrip:
    """Serialize → deserialize preserves shape, dtype, metadata."""

    def test_round_trip_shape(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
        serializer = EngramSerializer()
        out_path = tmp_data_dir / "test.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="test-agent", task_description="unit test",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.FP16,
        )
        restored_keys, restored_values, _meta = serializer.deserialize(out_path)

        assert restored_keys.shape == keys.shape
        assert restored_values.shape == values.shape

    def test_metadata_fields(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        serializer = EngramSerializer()
        out_path = tmp_data_dir / "meta.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="agent-42", task_description="metadata check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.Q8_0,
        )
        _, _, meta = serializer.deserialize(out_path)

        # Metadata values are stored as strings in the safetensors header.
        assert meta["agent_id"] == "agent-42"
        assert meta["task_description"] == "metadata check"
        assert meta["compression"] == "q8_0"
        assert meta["n_layers"] == "32"
        assert meta["model_family"] == "llama"

    def test_safetensors_loadable(self, tmp_data_dir: Path) -> None:
        """D7: File must be valid safetensors."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        serializer = EngramSerializer()
        out_path = tmp_data_dir / "valid.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="safetensors check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.FP16,
        )
        # Load with the vanilla safetensors reader, bypassing our serializer.
        tensors = load_file(str(out_path))
        assert "layer_0_keys" in tensors
        assert "layer_0_values" in tensors

    def test_result_dict(self, tmp_data_dir: Path) -> None:
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        serializer = EngramSerializer()
        out_path = tmp_data_dir / "result.eng"

        result = serializer.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="result check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
        )
        assert "cache_id" in result
        assert result["size_bytes"] > 0
        assert result["n_layers"] == 32
87
+
88
+
89
class TestSerializerErrors:
    """Edge cases and error handling."""

    def test_shape_mismatch_raises(self, tmp_data_dir: Path) -> None:
        # Keys and values disagree on context length (64 vs 32).
        keys = torch.randn(32, 8, 64, 128, dtype=torch.float16)
        values = torch.randn(32, 8, 32, 128, dtype=torch.float16)
        serializer = EngramSerializer()

        with pytest.raises(SerializationError, match="mismatch"):
            serializer.serialize(
                keys=keys, values=values,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_3d_tensor_raises(self, tmp_data_dir: Path) -> None:
        # Missing the leading layer dimension.
        bad = torch.randn(8, 64, 128, dtype=torch.float16)
        serializer = EngramSerializer()

        with pytest.raises(SerializationError, match="4D"):
            serializer.serialize(
                keys=bad, values=bad,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_missing_file_raises(self, tmp_data_dir: Path) -> None:
        with pytest.raises(SerializationError, match="not found"):
            EngramSerializer().deserialize(tmp_data_dir / "nonexistent.eng")
tests/test_state_extractor.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — State Extractor Tests
3
+ Tests for all 3 EGR extraction modes (D3).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import torch
9
+
10
+ from kvcos.core.cache_spec import LLAMA_3_1_8B, PHI_3_MINI
11
+ from kvcos.core.types import StateExtractionMode
12
+ from kvcos.core.state_extractor import MARStateExtractor
13
+ from tests.conftest import make_synthetic_kv
14
+
15
+
16
+ class TestMeanPool:
17
+ """mean_pool: mean over layers, heads, context → [head_dim]."""
18
+
19
+ def test_output_dim(self) -> None:
20
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
21
+ ext = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
22
+ result = ext.extract(keys, LLAMA_3_1_8B)
23
+ assert result.state_vec.shape == (128,)
24
+
25
+ def test_output_dim_api(self) -> None:
26
+ ext = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
27
+ assert ext.output_dim(LLAMA_3_1_8B) == 128
28
+
29
+ def test_l2_norm_positive(self) -> None:
30
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
31
+ ext = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
32
+ result = ext.extract(keys, LLAMA_3_1_8B)
33
+ assert result.l2_norm > 0
34
+
35
+ def test_deterministic(self) -> None:
36
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
37
+ ext = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
38
+ r1 = ext.extract(keys, LLAMA_3_1_8B)
39
+ r2 = ext.extract(keys, LLAMA_3_1_8B)
40
+ assert torch.equal(r1.state_vec, r2.state_vec)
41
+
42
+
43
+ class TestSVDProject:
44
+ """svd_project: truncated SVD, rank-160 → [rank]."""
45
+
46
+ def test_output_dim(self) -> None:
47
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
48
+ ext = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT, rank=160)
49
+ result = ext.extract(keys, LLAMA_3_1_8B)
50
+ assert result.state_vec.shape == (128,) # clamped to head_dim
51
+
52
+ def test_projection_stored(self) -> None:
53
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
54
+ ext = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT, rank=160)
55
+ ext.extract(keys, LLAMA_3_1_8B)
56
+ proj = ext.last_projection
57
+ assert proj is not None
58
+ assert 0.0 < proj.explained_variance_ratio <= 1.0
59
+
60
+ def test_n_layers_used(self) -> None:
61
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
62
+ ext = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT)
63
+ result = ext.extract(keys, LLAMA_3_1_8B)
64
+ assert result.n_layers_used == 24 # layers 8-31
65
+
66
+
67
+ class TestXKVProject:
68
+ """xkv_project: grouped cross-layer SVD."""
69
+
70
+ def test_output_dim(self) -> None:
71
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
72
+ ext = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT, rank=160)
73
+ result = ext.extract(keys, LLAMA_3_1_8B)
74
+ expected_dim = ext.output_dim(LLAMA_3_1_8B)
75
+ assert result.state_vec.shape == (expected_dim,)
76
+
77
+ def test_different_from_mean_pool(self) -> None:
78
+ keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
79
+ ext_mp = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
80
+ ext_xkv = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT)
81
+ r_mp = ext_mp.extract(keys, LLAMA_3_1_8B)
82
+ r_xkv = ext_xkv.extract(keys, LLAMA_3_1_8B)
83
+ assert r_mp.state_vec.shape != r_xkv.state_vec.shape
84
+
85
+ def test_phi3_works(self) -> None:
86
+ keys, _ = make_synthetic_kv(PHI_3_MINI, ctx_len=64)
87
+ ext = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT, rank=96)
88
+ result = ext.extract(keys, PHI_3_MINI)
89
+ assert result.state_vec.dim() == 1
90
+ assert result.state_vec.shape[0] > 0