test: upload 220 tests
Browse files- tests/__init__.py +0 -0
- tests/conftest.py +201 -0
- tests/test_blob_parser.py +125 -0
- tests/test_block_pool.py +92 -0
- tests/test_chunker.py +99 -0
- tests/test_compression.py +356 -0
- tests/test_eigengram.py +152 -0
- tests/test_embedder.py +106 -0
- tests/test_hnsw_index.py +108 -0
- tests/test_integration_synthetic.py +160 -0
- tests/test_iswa_blob_parser.py +136 -0
- tests/test_iswa_bridge.py +92 -0
- tests/test_iswa_fingerprint.py +94 -0
- tests/test_iswa_types.py +132 -0
- tests/test_knowledge_index.py +138 -0
- tests/test_manifest.py +167 -0
- tests/test_manifold_index.py +97 -0
- tests/test_retriever.py +86 -0
- tests/test_serializer.py +118 -0
- tests/test_state_extractor.py +90 -0
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Test Fixtures
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Shared pytest fixtures for all test modules.
|
| 6 |
+
Provides synthetic KV cache tensors at correct shapes,
|
| 7 |
+
temp directories, and model specs.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import pytest
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
from kvcos.core.cache_spec import GEMMA_4_26B_A4B, LLAMA_3_1_8B, PHI_3_MINI
|
| 18 |
+
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@pytest.fixture
def llama_spec() -> ModelCacheSpec:
    """Fixture: the Llama 3.1 8B cache-spec constant."""
    return LLAMA_3_1_8B
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture
def phi3_spec() -> ModelCacheSpec:
    """Fixture: the Phi-3-Mini cache-spec constant."""
    return PHI_3_MINI
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.fixture
def gemma4_spec() -> ModelCacheSpec:
    """Fixture: the Gemma 4 26B-A4B ISWA cache-spec constant."""
    return GEMMA_4_26B_A4B
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@pytest.fixture
def tmp_data_dir(tmp_path: Path) -> Path:
    """Fixture: fresh on-disk data directory for storage tests."""
    path = tmp_path / "engram_data"
    path.mkdir()
    return path
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@pytest.fixture
def tmp_index_dir(tmp_path: Path) -> Path:
    """Fixture: fresh directory for FAISS index persistence tests."""
    path = tmp_path / "engram_index"
    path.mkdir()
    return path
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def make_synthetic_kv(
    spec: ModelCacheSpec,
    ctx_len: int = 256,
    seed: int = 42,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build reproducible random KV tensors shaped like *spec*.

    Returns (keys, values), each [n_layers, n_kv_heads, ctx_len, head_dim]
    in float16. The same seed always yields the same tensors.
    """
    torch.manual_seed(seed)
    dims = (spec["n_layers"], spec["n_kv_heads"], ctx_len, spec["head_dim"])
    keys = torch.randn(dims, dtype=torch.float16)
    values = torch.randn(dims, dtype=torch.float16)
    return keys, values
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@pytest.fixture
def llama_kv_256(llama_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """256-token synthetic Llama 3.1 8B cache.

    Keys and values are each shaped [32, 8, 256, 128].
    """
    return make_synthetic_kv(llama_spec, ctx_len=256)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@pytest.fixture
def llama_kv_1024(llama_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """1024-token synthetic Llama 3.1 8B cache."""
    return make_synthetic_kv(llama_spec, ctx_len=1024, seed=123)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@pytest.fixture
def phi3_kv_256(phi3_spec: ModelCacheSpec) -> tuple[torch.Tensor, torch.Tensor]:
    """256-token synthetic Phi-3-Mini cache.

    Keys and values are each shaped [32, 32, 256, 96].
    """
    return make_synthetic_kv(phi3_spec, ctx_len=256, seed=99)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ── ISWA Fixtures ────────────────────────────────────────────────────────────
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def make_synthetic_iswa_blob(
    sections: tuple[CacheSection, ...],
    n_cells: int = 4,
    arch: str = "gemma4",
    v_trans: bool = True,
    seed: int = 42,
) -> bytes:
    """Serialize a fake ISWA state blob with one KV stream per section.

    The layout mirrors llama.cpp's llama_state_get_data() for ISWA models:
      1. arch string header (u32 length + ascii bytes)
      2. n_stream = len(sections)
      3. per stream: cell metadata, then K/V tensor data per layer

    Args:
        sections: Cache sections (e.g., global + SWA for Gemma 4).
        n_cells: KV cells per section.
        arch: Architecture string written to the header.
        v_trans: Store V tensors transposed (el_size + n_embd header)
            instead of row-major (row_size header).
        seed: RNG seed so blobs are reproducible.
    """
    import struct

    import numpy as np

    from kvcos.core.blob_parser import GGML_TYPE_F16

    rng = np.random.RandomState(seed)
    pack = struct.pack
    out: list[bytes] = []

    # Header: arch string length + ascii bytes.
    out.append(pack("<I", len(arch)))
    out.append(arch.encode("ascii"))

    # One KV stream per cache section.
    out.append(pack("<I", len(sections)))

    for section in sections:
        n_embd_kv = section.n_kv_heads * section.head_dim
        row_size = n_embd_kv * 2  # bytes per fp16 row

        # Cell metadata: (pos:i32, n_seq_id:u32, seq_id:i32) per cell.
        out.append(pack("<I", n_cells))
        for pos in range(n_cells):
            out.append(pack("<i", pos))
            out.append(pack("<I", 1))
            out.append(pack("<i", 0))

        # Data section header: v_trans flag + layer count.
        out.append(pack("<I", 1 if v_trans else 0))
        out.append(pack("<I", section.n_layers))

        # K tensors, one per layer.
        for _ in range(section.n_layers):
            out.append(pack("<i", GGML_TYPE_F16))
            out.append(pack("<Q", row_size))
            out.append(rng.randn(n_cells * n_embd_kv).astype(np.float16).tobytes())

        # V tensors, one per layer.
        for _ in range(section.n_layers):
            out.append(pack("<i", GGML_TYPE_F16))
            if v_trans:
                out.append(pack("<I", 2))          # el_size (fp16)
                out.append(pack("<I", n_embd_kv))  # n_embd_v_gqa
            else:
                out.append(pack("<Q", row_size))
            out.append(rng.randn(n_cells * n_embd_kv).astype(np.float16).tobytes())

    return b"".join(out)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# Gemma 4 ISWA section constants (reverse-engineered)
# NOTE(review): these layer/head/dim values were reverse-engineered from
# observed state blobs — confirm against the model's published config.

# Full-attention ("global") layers of the Gemma 4 26B-A4B cache.
GEMMA4_GLOBAL_SECTION = CacheSection(
    attention_type=AttentionType.FULL,
    n_layers=5,
    n_kv_heads=2,
    head_dim=512,
)

# Sliding-window-attention layers with a 1024-token window.
GEMMA4_SWA_SECTION = CacheSection(
    attention_type=AttentionType.SLIDING,
    n_layers=25,
    n_kv_heads=8,
    head_dim=256,
    window_size=1024,
)

# Ordering matters: global section first, then SWA — the blob builder
# emits one stream per section in this order.
GEMMA4_SECTIONS = (GEMMA4_GLOBAL_SECTION, GEMMA4_SWA_SECTION)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@pytest.fixture
def gemma4_iswa_blob() -> bytes:
    """Synthetic 2-section Gemma 4 ISWA blob with 4 cells per section."""
    return make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@pytest.fixture
def gemma4_iswa_blob_8cells() -> bytes:
    """Synthetic 2-section Gemma 4 ISWA blob with 8 cells per section."""
    return make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=8, seed=99)
|
tests/test_blob_parser.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Blob Parser Tests
|
| 3 |
+
Tests for llama.cpp state blob → structured tensors (D1).
|
| 4 |
+
Uses synthetic blobs matching the real llama_state_get_data() format.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import struct
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pytest
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from kvcos.core.blob_parser import (
|
| 16 |
+
GGML_TYPE_F16,
|
| 17 |
+
BlobParseError,
|
| 18 |
+
ParsedKVCache,
|
| 19 |
+
parse_state_blob,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _make_blob(
    n_cells: int,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    arch: str = "llama",
    v_trans: bool = True,
    seed: int | None = None,
) -> bytes:
    """Build a synthetic blob matching llama_state_get_data() format.

    Args:
        n_cells: Number of KV cells (tokens) in the single stream.
        n_layers: Number of transformer layers.
        n_kv_heads: KV heads per layer.
        head_dim: Per-head dimension.
        arch: Architecture string written to the header.
        v_trans: Store V tensors transposed (el_size + n_embd header)
            instead of row-major (row_size header).
        seed: Optional RNG seed for reproducible tensor data; None keeps
            the original unseeded (random-per-call) behavior.
    """
    # Local RandomState: reproducible when seeded, and never mutates
    # numpy's global RNG state (matches conftest's seeded blob builder).
    rng = np.random.RandomState(seed)
    parts: list[bytes] = []

    # 1. Architecture string header
    parts.append(struct.pack("<I", len(arch)))
    parts.append(arch.encode("ascii"))

    # 2. KV stream header
    parts.append(struct.pack("<I", 1))        # n_stream = 1
    parts.append(struct.pack("<I", n_cells))  # cell_count

    # 3. Cell metadata: (pos:i32, n_seq:u32, seq_id:i32) per cell
    for i in range(n_cells):
        parts.append(struct.pack("<i", i))  # pos
        parts.append(struct.pack("<I", 1))  # n_seq_id = 1
        parts.append(struct.pack("<i", 0))  # seq_id = 0

    # 4. Data section header
    parts.append(struct.pack("<I", 1 if v_trans else 0))  # v_trans
    parts.append(struct.pack("<I", n_layers))

    n_embd_kv = n_kv_heads * head_dim
    row_size = n_embd_kv * 2  # fp16 bytes per row

    # 5. K layers: (type, row_size, data) per layer
    for _ in range(n_layers):
        parts.append(struct.pack("<i", GGML_TYPE_F16))  # type_k
        parts.append(struct.pack("<Q", row_size))       # row_size_k
        data = rng.randn(n_cells * n_embd_kv).astype(np.float16)
        parts.append(data.tobytes())

    # 6. V layers: header shape depends on v_trans
    for _ in range(n_layers):
        parts.append(struct.pack("<i", GGML_TYPE_F16))  # type_v
        if v_trans:
            parts.append(struct.pack("<I", 2))          # el_size (fp16)
            parts.append(struct.pack("<I", n_embd_kv))  # n_embd_v_gqa
        else:
            parts.append(struct.pack("<Q", row_size))   # row_size_v
        data = rng.randn(n_cells * n_embd_kv).astype(np.float16)
        parts.append(data.tobytes())

    return b"".join(parts)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TestBlobParser:
    """Parsing of synthetic blobs laid out like llama_state_get_data()."""

    def test_parse_shape(self) -> None:
        parsed = parse_state_blob(_make_blob(16, 32, 8, 128), n_kv_heads=8, head_dim=128)
        expected = (32, 8, 16, 128)  # [layers, kv_heads, cells, head_dim]
        assert parsed.keys.shape == expected
        assert parsed.values.shape == expected

    def test_parse_metadata(self) -> None:
        parsed = parse_state_blob(_make_blob(8, 32, 8, 128), n_kv_heads=8, head_dim=128)
        assert parsed.n_cells == 8
        assert parsed.n_layers == 32
        assert parsed.arch == "llama"
        assert parsed.v_trans is True
        # Cell positions are sequential starting at 0.
        assert len(parsed.cells) == 8
        assert parsed.cells[0].pos == 0
        assert parsed.cells[7].pos == 7

    def test_dtype_float16(self) -> None:
        parsed = parse_state_blob(_make_blob(4, 28, 8, 128), n_kv_heads=8, head_dim=128)
        assert parsed.keys.dtype == torch.float16
        assert parsed.values.dtype == torch.float16

    def test_non_transposed_v(self) -> None:
        blob = _make_blob(4, 28, 8, 128, v_trans=False)
        parsed = parse_state_blob(blob, n_kv_heads=8, head_dim=128)
        assert parsed.values.shape == (28, 8, 4, 128)
        assert parsed.v_trans is False
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TestBlobParserErrors:
    """Malformed-blob edge cases."""

    def test_zero_cells_raises(self) -> None:
        # Valid arch header, n_stream=1, but cell_count=0, then padding.
        blob = struct.pack("<I", 5) + b"llama" + struct.pack("<II", 1, 0) + b"\x00" * 20
        with pytest.raises(BlobParseError, match="0 cells"):
            parse_state_blob(blob, n_kv_heads=8, head_dim=128)

    def test_truncated_blob_raises(self) -> None:
        truncated = _make_blob(4, 28, 8, 128)[:100]
        with pytest.raises(BlobParseError):
            parse_state_blob(truncated, n_kv_heads=8, head_dim=128)

    def test_bad_arch_length_raises(self) -> None:
        # Arch length field claims 999 bytes but only 100 follow.
        blob = struct.pack("<I", 999) + b"x" * 100
        with pytest.raises(BlobParseError, match="too large"):
            parse_state_blob(blob, n_kv_heads=8, head_dim=128)
|
tests/test_block_pool.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Block Pool Tests
|
| 3 |
+
Tests for 256-token block segmentation/assembly/extend.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from kvcos.core.block_pool import BlockPool, KVBlock
|
| 12 |
+
from kvcos.core.types import BLOCK_SIZE_TOKENS
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _kv(n_layers: int, n_heads: int, ctx: int, dim: int) -> tuple[torch.Tensor, torch.Tensor]:
|
| 16 |
+
k = torch.randn(n_layers, n_heads, ctx, dim, dtype=torch.float16)
|
| 17 |
+
return k, k.clone()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestSegment:
    """Splitting a full KV cache into 256-token blocks."""

    def test_exact_blocks(self) -> None:
        k, v = _kv(32, 8, 512, 128)
        blocks = BlockPool(agent_id="a", model_id="m").segment(k, v)
        # 512 tokens → exactly two full blocks.
        assert len(blocks) == 2
        assert all(blk.is_full for blk in blocks)

    def test_partial_last_block(self) -> None:
        k, v = _kv(32, 8, 300, 128)
        blocks = BlockPool(agent_id="a", model_id="m").segment(k, v)
        assert len(blocks) == 2
        assert blocks[0].is_full
        assert not blocks[1].is_full
        assert blocks[1].block_len == 44  # 300 - 256 leftover tokens

    def test_total_tokens(self) -> None:
        k, v = _kv(32, 8, 700, 128)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert pool.total_tokens == 700
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestAssemble:
    """Reassembling blocks into a contiguous KV cache."""

    def test_round_trip(self) -> None:
        k_in, v_in = _kv(4, 2, 512, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k_in, v_in)
        k_out, v_out = pool.assemble()
        # Segment + assemble must be lossless.
        assert torch.equal(k_out, k_in)

    def test_subset_assembly(self) -> None:
        k_in, v_in = _kv(4, 2, 768, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k_in, v_in)
        k_out, _ = pool.assemble(block_indices=[0, 2])
        # Two full blocks' worth of tokens along the context axis.
        assert k_out.shape[2] == BLOCK_SIZE_TOKENS * 2

    def test_empty_raises(self) -> None:
        with pytest.raises(ValueError, match="No blocks"):
            BlockPool(agent_id="a", model_id="m").assemble()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class TestExtend:
    """Appending new tokens to an existing pool."""

    def test_fills_partial_block(self) -> None:
        k, v = _kv(4, 2, 200, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert not pool.blocks[-1].is_full

        # 200 + 56 = 256: the partial tail block becomes exactly full.
        extra_k, extra_v = _kv(4, 2, 56, 64)
        pool.extend(extra_k, extra_v)
        assert pool.blocks[-1].is_full
        assert pool.total_tokens == 256

    def test_extend_creates_new_blocks(self) -> None:
        k, v = _kv(4, 2, 256, 64)
        pool = BlockPool(agent_id="a", model_id="m")
        pool.segment(k, v)
        assert pool.n_blocks == 1

        # 300 extra tokens spill into two additional blocks.
        extra_k, extra_v = _kv(4, 2, 300, 64)
        pool.extend(extra_k, extra_v)
        assert pool.n_blocks == 3
        assert pool.total_tokens == 556
|
tests/test_chunker.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for kvcos.engram.chunker — markdown-aware semantic chunker."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from kvcos.engram.chunker import Chunk, chunk_markdown, eng_filename, slug_from_path
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TestChunkMarkdown:
    """Behavior of the markdown-aware semantic chunker."""

    def test_empty_content(self):
        assert chunk_markdown("") == []
        assert chunk_markdown(" ") == []

    def test_small_file_single_chunk(self):
        text = "# Title\n\nSome short content."
        result = chunk_markdown(text, max_chars=2000)
        assert len(result) == 1
        only = result[0]
        assert only.index == 0
        assert only.char_start == 0
        assert only.char_end == len(text)

    def test_large_file_splits(self):
        # Two headered sections whose combined size exceeds max_chars.
        text = "# Section 1\n\n" + "A" * 1500 + "\n\n# Section 2\n\n" + "B" * 1500
        assert len(chunk_markdown(text, max_chars=2000)) >= 2

    def test_chunks_cover_full_content(self):
        text = "# A\n\nText A.\n\n# B\n\nText B.\n\n# C\n\nText C."
        joined = " ".join(c.raw_text for c in chunk_markdown(text, max_chars=15))
        # Every section's body must survive chunking.
        assert "Text A" in joined
        assert "Text B" in joined
        assert "Text C" in joined

    def test_context_prefix(self):
        result = chunk_markdown("Hello world", context_prefix="Source: test.md")
        assert len(result) == 1
        assert result[0].text.startswith("Source: test.md")

    def test_indices_sequential(self):
        text = "# A\n\n" + "X" * 3000 + "\n\n# B\n\n" + "Y" * 3000
        result = chunk_markdown(text, max_chars=2000)
        assert [c.index for c in result] == list(range(len(result)))

    def test_merge_small_sections(self):
        """Small consecutive sections should merge into one chunk."""
        text = "# A\n\nShort.\n\n# B\n\nAlso short.\n\n# C\n\nStill short."
        assert len(chunk_markdown(text, max_chars=2000, min_chars=100)) == 1

    def test_paragraph_split_fallback(self):
        """Content without headers should split on paragraphs."""
        body = "\n\n".join("Paragraph " + str(i) + ". " + "X" * 500 for i in range(6))
        assert len(chunk_markdown(body, max_chars=1500)) >= 2
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TestSlugFromPath:
    """Filename → kebab-case slug conversion."""

    def test_simple_filename(self):
        assert slug_from_path("readme.md") == "readme"

    def test_uppercase_underscores(self):
        # Uppercase and underscores normalize to lowercase kebab-case.
        assert slug_from_path("EIGENGRAM_SPEC.md") == "eigengram-spec"

    def test_already_kebab(self):
        # Already-kebab names pass through unchanged.
        assert slug_from_path("coding-style.md") == "coding-style"

    def test_full_path(self):
        # Directory components are stripped before slugging.
        assert slug_from_path("/Users/test/docs/my_doc.md") == "my-doc"

    def test_special_chars(self):
        assert slug_from_path("file (copy).md") == "file-copy"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TestEngFilename:
    """Generated .eng artifact filenames."""

    def test_single_chunk(self):
        assert eng_filename("engram", "readme", "2026-04-02") == "readme_2026-04-02.eng"

    def test_multi_chunk(self):
        # Multi-chunk files get a zero-padded, 1-based chunk number.
        result = eng_filename("engram", "geodesic3", "2026-04-02",
                              chunk_index=0, chunk_total=5)
        assert result == "geodesic3_001_2026-04-02.eng"

    def test_with_time(self):
        result = eng_filename("engram", "session", "2026-04-02", time_str="1430")
        assert result == "session_2026-04-02_1430.eng"

    def test_single_chunk_no_index(self):
        """Single-chunk files should not have chunk number."""
        result = eng_filename("engram", "small", "2026-04-02",
                              chunk_index=0, chunk_total=1)
        assert result == "small_2026-04-02.eng"
|
tests/test_compression.py
ADDED
|
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Compression Tests
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Tests for kvcos.core.compression:
|
| 6 |
+
- FP16 passthrough
|
| 7 |
+
- Q8_0 round-trip accuracy & shape preservation
|
| 8 |
+
- PolarQuant round-trip accuracy & rotation invariants
|
| 9 |
+
- Dispatcher routing and Q4_0 fallback warning
|
| 10 |
+
- Edge cases: padding, single-element groups
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
import pytest
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from kvcos.core.compression import (
|
| 21 |
+
Q8_GROUP_SIZE,
|
| 22 |
+
CompressionResult,
|
| 23 |
+
compress,
|
| 24 |
+
compress_fp16,
|
| 25 |
+
compress_polarquant,
|
| 26 |
+
compress_q8_0,
|
| 27 |
+
decompress,
|
| 28 |
+
decompress_fp16,
|
| 29 |
+
decompress_polarquant,
|
| 30 |
+
decompress_q8_0,
|
| 31 |
+
)
|
| 32 |
+
from kvcos.core.types import CompressionMethod
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ── FP16 Passthrough ──────────────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TestFP16:
    """FP16 passthrough: dtype normalization, no quantization."""

    def test_fp16_passthrough_shape(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).data.shape == keys.shape

    def test_fp16_passthrough_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).data.dtype == torch.float16

    def test_fp16_passthrough_exact(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        # Passthrough is bit-exact for already-fp16 input.
        assert torch.equal(compress_fp16(keys).data, keys.to(torch.float16))

    def test_fp16_compression_ratio_one(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).compression_ratio == 1.0

    def test_fp16_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        assert compress_fp16(keys).method == CompressionMethod.FP16

    def test_fp16_from_fp32(self) -> None:
        """FP32 input is cast to FP16."""
        tensor = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        result = compress_fp16(tensor)
        assert result.data.dtype == torch.float16
        assert result.original_dtype == torch.float32

    def test_fp16_decompress_identity(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        keys = llama_kv_256[0]
        compressed = compress_fp16(keys)
        # Decompress is the identity for FP16 passthrough.
        assert torch.equal(decompress_fp16(compressed.data), compressed.data)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ── Q8_0 Quantization ────────────────────────────────────────────────────────
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TestQ8_0:
    """Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0."""

    def test_q8_0_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compression must not alter the stored tensor shape."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.data.shape == keys.shape

    def test_q8_0_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 stores dequantized bfloat16 for safetensors compat."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.data.dtype == torch.bfloat16

    def test_q8_0_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Result is tagged with the Q8_0 compression method."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.method == CompressionMethod.Q8_0

    def test_q8_0_metadata_group_size(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Metadata records the quantization group size (as a string)."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.metadata["q8_group_size"] == str(Q8_GROUP_SIZE)

    def test_q8_0_round_trip_low_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 quantization error should be < 1% relative MSE."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        decompressed = decompress_q8_0(result.data)

        original = keys.float()
        restored = decompressed.float()

        mse = ((original - restored) ** 2).mean()
        signal_power = (original**2).mean()
        relative_mse = (mse / signal_power).item()
        assert relative_mse < 0.01, f"Q8_0 relative MSE {relative_mse:.6f} > 1%"

    def test_q8_0_round_trip_values(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 round-trip on Phi-3 (head_dim=96, needs padding)."""
        keys, values = phi3_kv_256
        for tensor in (keys, values):
            result = compress_q8_0(tensor)
            assert result.data.shape == tensor.shape

    def test_q8_0_compression_ratio_fp32(self) -> None:
        """FP32 input → bfloat16 output gives 2x compression ratio."""
        t = torch.randn(2, 4, 64, 128, dtype=torch.float32)
        result = compress_q8_0(t)
        assert abs(result.compression_ratio - 2.0) < 0.01

    def test_q8_0_compression_ratio_fp16(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """FP16 input → bfloat16 output gives 1x ratio (same byte width)."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert abs(result.compression_ratio - 1.0) < 0.01

    def test_q8_0_preserves_original_dtype(self) -> None:
        """The pre-compression dtype is recorded on the result."""
        t = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        result = compress_q8_0(t)
        assert result.original_dtype == torch.float32

    def test_q8_0_padding_dim_not_divisible(self) -> None:
        """Head dims not divisible by 32 get padded then unpadded."""
        t = torch.randn(2, 4, 16, 96, dtype=torch.float16)  # 96 = 3*32, exact
        result = compress_q8_0(t)
        assert result.data.shape == t.shape

        t2 = torch.randn(2, 4, 16, 100, dtype=torch.float16)  # 100 not div by 32
        result2 = compress_q8_0(t2)
        assert result2.data.shape == t2.shape

    def test_q8_0_zero_tensor(self) -> None:
        """All-zero tensor should round-trip exactly."""
        t = torch.zeros(2, 4, 16, 128, dtype=torch.float16)
        result = compress_q8_0(t)
        decompressed = decompress_q8_0(result.data)
        # `t` is already float16, so no cast is needed before comparing
        # (previously a redundant t.to(torch.float16) was applied here).
        assert torch.allclose(decompressed, t, atol=1e-6)
+
|
| 189 |
+
# ── PolarQuant ───────────────────────────────────────────────────────────────
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class TestPolarQuant:
    """PolarQuant: MSE-optimal random rotation + Lloyd-Max at 3 bits.

    QJL intentionally absent (D5).
    """

    def test_polarquant_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compressing must leave the tensor shape untouched."""
        keys = llama_kv_256[0]
        res = compress_polarquant(keys)
        assert res.data.shape == keys.shape

    def test_polarquant_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Stored payload is bfloat16."""
        keys = llama_kv_256[0]
        res = compress_polarquant(keys)
        assert res.data.dtype == torch.bfloat16

    def test_polarquant_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Result carries the POLARQUANT method tag."""
        keys = llama_kv_256[0]
        res = compress_polarquant(keys)
        assert res.method == CompressionMethod.POLARQUANT

    def test_polarquant_metadata_qjl_disabled(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """D5: QJL must be marked disabled in metadata."""
        keys = llama_kv_256[0]
        res = compress_polarquant(keys)
        assert res.metadata["qjl_enabled"] == "false"
        assert res.metadata["polarquant_bits"] == "3"

    def test_polarquant_round_trip_bounded_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """PolarQuant 3-bit error should be < 15% relative MSE.

        3-bit Lloyd-Max on rotated Gaussian: theoretical ~10% for 8 centroids.
        Allow margin for rotation + dtype casting.
        """
        keys = llama_kv_256[0]
        res = compress_polarquant(keys)
        restored = decompress_polarquant(res.data).float()
        original = keys.float()

        err = ((original - restored) ** 2).mean()
        power = (original**2).mean()
        relative_mse = (err / power).item()
        assert relative_mse < 0.15, f"PolarQuant relative MSE {relative_mse:.4f} > 15%"

    def test_polarquant_worse_than_q8_0(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """3-bit PolarQuant should have higher error than 8-bit Q8_0."""
        keys = llama_kv_256[0]
        original = keys.float()

        q8_restored = decompress_q8_0(compress_q8_0(keys).data).float()
        pq_restored = decompress_polarquant(compress_polarquant(keys).data).float()

        q8_mse = ((original - q8_restored) ** 2).mean()
        pq_mse = ((original - pq_restored) ** 2).mean()

        assert pq_mse > q8_mse, "PolarQuant 3-bit should be less accurate than Q8_0"

    def test_polarquant_deterministic(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Same input → same output (fixed seed rotation matrix)."""
        keys = llama_kv_256[0]
        first = compress_polarquant(keys)
        second = compress_polarquant(keys)
        assert torch.equal(first.data, second.data)

    def test_polarquant_phi3_shape(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Phi-3 head_dim=96 works with PolarQuant."""
        keys = phi3_kv_256[0]
        res = compress_polarquant(keys)
        assert res.data.shape == keys.shape
+
|
| 282 |
+
# ── Dispatcher ───────────────────────────────────────────────────────────────
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class TestDispatcher:
    """compress() and decompress() dispatch to correct implementations."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_compress_dispatches(self, method: CompressionMethod) -> None:
        """Each supported method yields a result tagged with that method."""
        tensor = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        res = compress(tensor, method)
        assert isinstance(res, CompressionResult)
        assert res.method == method

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_decompress_returns_fp16(self, method: CompressionMethod) -> None:
        """decompress() always hands back float16 tensors."""
        tensor = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        res = compress(tensor, method)
        restored = decompress(res.data, method)
        assert restored.dtype == torch.float16

    def test_q4_0_warns_and_falls_back(self) -> None:
        """D5: Q4_0 emits warning and uses Q8_0 instead."""
        tensor = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            res = compress(tensor, CompressionMethod.Q4_0)
        assert len(caught) == 1
        message = str(caught[0].message)
        assert "Q4_0" in message
        assert "92%" in message
        assert res.method == CompressionMethod.Q8_0

    def test_unknown_method_raises(self) -> None:
        """An unrecognized method is rejected with ValueError."""
        tensor = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            compress(tensor, "invalid_method")  # type: ignore[arg-type]

    def test_decompress_unknown_raises(self) -> None:
        """decompress() rejects unknown methods the same way."""
        tensor = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            decompress(tensor, "invalid_method")  # type: ignore[arg-type]
|
| 330 |
+
# ── Round-trip Integration ───────────────────────────────────────────────────
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
class TestRoundTrip:
    """Full compress → decompress round-trip through dispatcher."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_round_trip_shape_preserved(self, method: CompressionMethod) -> None:
        """Shape survives a compress/decompress cycle for every method."""
        tensor = torch.randn(4, 8, 64, 128, dtype=torch.float16)
        restored = decompress(compress(tensor, method).data, method)
        assert restored.shape == tensor.shape

    def test_round_trip_both_kv(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compress and decompress both keys and values."""
        for tensor in llama_kv_256:
            for method in (CompressionMethod.FP16, CompressionMethod.Q8_0):
                restored = decompress(compress(tensor, method).data, method)
                assert restored.shape == tensor.shape
                assert restored.dtype == torch.float16
tests/test_eigengram.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EIGENGRAM test suite — no model calls, pure format verification.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import struct
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from kvcos.engram.format import (
|
| 14 |
+
EigramDecoder,
|
| 15 |
+
EigramEncoder,
|
| 16 |
+
EIGENGRAM_MAGIC,
|
| 17 |
+
EIGENGRAM_VERSION,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
BASIS_PATH = "results/corpus_basis_fcdb_v2.pt"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.fixture(scope="module")
def basis():
    """Load the prebuilt FCDB v2 basis, skipping the module when absent."""
    if os.path.exists(BASIS_PATH):
        return torch.load(BASIS_PATH, weights_only=False)
    pytest.skip("FCDB v2 basis not built yet")
| 29 |
+
|
| 30 |
+
@pytest.fixture(scope="module")
def sample_cert(basis):
    """Encode one synthetic certificate against the loaded basis."""
    rank = basis["basis"].shape[0]
    encoder = EigramEncoder()
    return encoder.encode(
        vec_perdoc=torch.randn(rank),
        vec_fcdb=torch.randn(rank),
        joint_center=basis["joint_center"],
        corpus_hash="a" * 32,
        model_id="Llama-3.1-8B",
        basis_rank=rank,
        n_corpus=200,
        layer_range=(8, 24),
        context_len=512,
        l2_norm=1.234,
        scs=0.42,
        margin_proof=0.013,
        task_description="Test document for transformer attention.",
        cache_id="test-doc-001",
    )
+
|
| 52 |
+
class TestFormat:
    """Binary-layout checks on an encoded certificate."""

    def test_magic_present(self, sample_cert: bytes) -> None:
        """First four bytes are the EIGENGRAM magic."""
        assert sample_cert[:4] == EIGENGRAM_MAGIC

    def test_version_byte(self, sample_cert: bytes) -> None:
        """Byte 4 carries the format version."""
        (version,) = struct.unpack_from("<B", sample_cert, 4)
        assert version == EIGENGRAM_VERSION

    def test_minimum_size(self, sample_cert: bytes, basis) -> None:
        """Cert holds at least header + two R-vectors + 128-dim center (2 bytes each)."""
        rank = basis["basis"].shape[0]
        floor = 99 + rank * 2 + rank * 2 + 128 * 2
        assert len(sample_cert) >= floor

    def test_file_size_reasonable(self, sample_cert: bytes) -> None:
        """Certificates stay comfortably under 2 KiB."""
        assert len(sample_cert) < 2048
+
|
| 68 |
+
class TestRoundTrip:
    """Encode → decode field-by-field fidelity checks."""

    def test_model_id(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["model_id"] == "Llama-3.1-8B"

    def test_basis_rank(self, sample_cert: bytes, basis) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["basis_rank"] == basis["basis"].shape[0]

    def test_vec_perdoc_shape(self, sample_cert: bytes, basis) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["vec_perdoc"].shape == (basis["basis"].shape[0],)

    def test_vec_fcdb_shape(self, sample_cert: bytes, basis) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["vec_fcdb"].shape == (basis["basis"].shape[0],)

    def test_joint_center_shape(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["joint_center"].shape == (128,)

    def test_scs(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert abs(decoded["scs"] - 0.42) < 0.01

    def test_margin_proof(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert abs(decoded["margin_proof"] - 0.013) < 0.001

    def test_task_description(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert "transformer" in decoded["task_description"]

    def test_cache_id(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["cache_id"] == "test-doc-001"

    def test_layer_range(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["layer_range"] == (8, 24)

    def test_n_corpus(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["n_corpus"] == 200

    def test_context_len(self, sample_cert: bytes) -> None:
        decoded = EigramDecoder().decode(sample_cert)
        assert decoded["context_len"] == 512

    def test_float16_cosine_preserved(self, basis) -> None:
        """fp16 storage keeps the vector direction (cosine > 0.999)."""
        rank = basis["basis"].shape[0]
        vec = torch.randn(rank)
        vec = vec / vec.norm()
        cert = EigramEncoder().encode(
            vec_perdoc=vec,
            vec_fcdb=vec,
            joint_center=basis["joint_center"],
            corpus_hash="a" * 32,
            model_id="test",
            basis_rank=rank,
            n_corpus=200,
            layer_range=(8, 24),
            context_len=0,
            l2_norm=1.0,
            scs=0.5,
            margin_proof=0.0,
            task_description="cosine test",
            cache_id="cos",
        )
        decoded = EigramDecoder().decode(cert)
        cos = torch.nn.functional.cosine_similarity(
            vec.unsqueeze(0), decoded["vec_perdoc"].unsqueeze(0)
        ).item()
        assert cos > 0.999, f"Cosine after round-trip: {cos:.5f}"
| 137 |
+
|
| 138 |
+
class TestErrorHandling:
    """Decoder rejects malformed certificates."""

    def test_bad_magic_raises(self) -> None:
        """A wrong magic prefix is refused."""
        corrupt = b"XXXX" + b"\x00" * 200
        with pytest.raises(ValueError, match="magic"):
            EigramDecoder().decode(corrupt)

    def test_wrong_version_raises(self, sample_cert: bytes) -> None:
        """An unsupported version byte is refused."""
        mutated = bytearray(sample_cert)
        mutated[4] = 99
        with pytest.raises(ValueError, match="version"):
            EigramDecoder().decode(bytes(mutated))

    def test_truncated_raises(self, sample_cert: bytes) -> None:
        """Decoding a truncated payload raises (any exception type)."""
        with pytest.raises(Exception):
            EigramDecoder().decode(sample_cert[:20])
tests/test_embedder.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for kvcos.engram.embedder — unified fingerprint embedding."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from kvcos.engram.embedder import (
|
| 8 |
+
HashEmbedder,
|
| 9 |
+
get_embedder,
|
| 10 |
+
get_fingerprint,
|
| 11 |
+
reset_embedder,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TestHashEmbedder:
    """Deterministic hash-based fallback embedder."""

    def test_deterministic(self):
        """The same text always maps to the same vector."""
        embedder = HashEmbedder(dim=128)
        assert torch.allclose(embedder.embed("hello"), embedder.embed("hello"))

    def test_different_text(self):
        """Distinct texts map to distinct vectors."""
        embedder = HashEmbedder(dim=128)
        assert not torch.allclose(embedder.embed("hello"), embedder.embed("world"))

    def test_normalized(self):
        """Output vectors are unit-length."""
        fingerprint = HashEmbedder(dim=128).embed("test")
        assert abs(torch.norm(fingerprint).item() - 1.0) < 0.01

    def test_dimension(self):
        """The configured dimension is honored and exposed."""
        embedder = HashEmbedder(dim=256)
        assert embedder.embed("test").shape == (256,)
        assert embedder.dim == 256

    def test_source_tag(self):
        """Embedder identifies itself as the hash fallback."""
        assert HashEmbedder().source == "hash-fallback"
| 44 |
+
|
| 45 |
+
class TestGetFingerprint:
    """Module-level helper returning (vector, source) pairs."""

    def test_returns_tensor_and_source(self):
        """Output is a tensor plus one of the known backend tags."""
        fingerprint, source = get_fingerprint("test text")
        assert isinstance(fingerprint, torch.Tensor)
        assert isinstance(source, str)
        assert source in ("llama_cpp", "sbert", "hash-fallback")

    def test_deterministic(self):
        """Repeated calls on the same text agree."""
        first, _ = get_fingerprint("same text")
        second, _ = get_fingerprint("same text")
        assert torch.allclose(first, second)
+
|
| 58 |
+
class TestSBertEmbedder:
    """Test sbert if available (installed in this venv)."""

    def test_sbert_available(self):
        """Verify sentence-transformers is usable."""
        try:
            from kvcos.engram.embedder import SBertEmbedder

            embedder = SBertEmbedder()
        except ImportError:
            pytest.skip("sentence-transformers not installed")
        else:
            assert embedder.source == "sbert"
            assert embedder.dim == 384

    def test_semantic_discrimination(self):
        """Related texts should be more similar than unrelated."""
        try:
            from kvcos.engram.embedder import SBertEmbedder

            embedder = SBertEmbedder()
        except ImportError:
            pytest.skip("sentence-transformers not installed")

        fp_ml = embedder.embed("machine learning neural network training")
        fp_dl = embedder.embed("deep learning model optimization")
        fp_cake = embedder.embed("chocolate cake baking recipe")

        sim_related = F.cosine_similarity(fp_ml.unsqueeze(0), fp_dl.unsqueeze(0)).item()
        sim_unrelated = F.cosine_similarity(
            fp_ml.unsqueeze(0), fp_cake.unsqueeze(0)
        ).item()

        assert sim_related > sim_unrelated, (
            f"Related topics ({sim_related:.4f}) should be more similar "
            f"than unrelated ({sim_unrelated:.4f})"
        )
+
|
| 92 |
+
class TestGetEmbedder:
    """Process-wide embedder singleton management."""

    def test_singleton(self):
        """Two consecutive gets return the same instance."""
        reset_embedder()
        assert get_embedder() is get_embedder()

    def test_reset(self):
        """Reset still leaves a usable embedder behind."""
        reset_embedder()
        first = get_embedder()
        reset_embedder()
        second = get_embedder()
        # After reset, a new instance is created
        # (may or may not be same object depending on strategy)
        assert second is not None
tests/test_hnsw_index.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for kvcos.engram.hnsw_index — HNSW nearest-neighbor index."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from kvcos.engram.hnsw_index import EngramIndex, HNSWResult
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@pytest.fixture
def small_index():
    """Build a small 8-dim HNSW index with 5 documents."""
    index = EngramIndex(dim=8)
    doc_ids = [f"doc_{i}" for i in range(5)]
    # deterministic orthogonal-ish vectors
    index.add_batch(doc_ids, torch.eye(5, 8))
    return index
| 19 |
+
|
| 20 |
+
class TestEngramIndexBuild:
    """Index construction bookkeeping."""

    def test_add_batch_len(self, small_index):
        """All five documents are counted."""
        assert len(small_index) == 5

    def test_add_batch_ids_stored(self, small_index):
        """Ids are kept in insertion order (checks internal _ids list)."""
        assert small_index._ids == [f"doc_{i}" for i in range(5)]

    def test_repr(self, small_index):
        """repr() reports both the element count and dimensionality."""
        text = repr(small_index)
        assert "n=5" in text
        assert "dim=8" in text
+
|
| 33 |
+
class TestEngramIndexSearch:
    """Nearest-neighbor queries against the small fixture index."""

    def test_search_returns_results(self, small_index):
        """Exact-match query returns its own document first."""
        probe = torch.eye(5, 8)[0]  # matches doc_0
        hits = small_index.search(probe, top_k=3)
        assert len(hits) == 3
        assert hits[0].doc_id == "doc_0"

    def test_search_scores_descending(self, small_index):
        """Scores come back sorted best-first."""
        hits = small_index.search(torch.eye(5, 8)[2], top_k=5)
        scores = [h.score for h in hits]
        assert scores == sorted(scores, reverse=True)

    def test_search_margin(self, small_index):
        """Top hit's margin over the runner-up is non-negative."""
        hits = small_index.search(torch.eye(5, 8)[0], top_k=3)
        assert hits[0].margin >= 0

    def test_search_raises_before_build(self):
        """Searching an empty index is an error."""
        empty = EngramIndex(dim=8)
        with pytest.raises(RuntimeError, match="not built"):
            empty.search(torch.randn(8), top_k=1)
+
|
| 57 |
+
class TestEngramIndexGetVector:
    """Vector reconstruction by document id."""

    def test_get_vector_returns_tensor(self, small_index):
        """A stored id yields an 8-dim tensor."""
        vector = small_index.get_vector("doc_0")
        assert vector is not None
        assert isinstance(vector, torch.Tensor)
        assert vector.shape == (8,)

    def test_get_vector_none_for_missing(self, small_index):
        """An unknown id yields None rather than raising."""
        assert small_index.get_vector("nonexistent") is None

    def test_get_vector_reconstructs_normalized(self, small_index):
        """Vectors are L2-normalized on add, so reconstruction should be unit-length."""
        vector = small_index.get_vector("doc_0")
        assert abs(torch.norm(vector).item() - 1.0) < 0.01

    def test_get_vector_matches_original_direction(self, small_index):
        """Reconstructed vector should point in the same direction as the original."""
        expected = torch.nn.functional.normalize(torch.eye(5, 8)[3:4], dim=-1)[0]
        actual = small_index.get_vector("doc_3")
        assert torch.dot(expected, actual).item() > 0.99
+
|
| 82 |
+
class TestEngramIndexPersistence:
    """Save/load round-trips through a temp directory."""

    def test_save_and_load(self, small_index, tmp_path):
        """A reloaded index preserves size and id ordering."""
        target = str(tmp_path / "test_hnsw")
        small_index.save(target)

        reloaded = EngramIndex.load(target)
        assert len(reloaded) == 5
        assert reloaded._ids == small_index._ids

    def test_loaded_search_matches_original(self, small_index, tmp_path):
        """Search results are identical before and after persistence."""
        target = str(tmp_path / "test_hnsw")
        small_index.save(target)
        reloaded = EngramIndex.load(target)

        probe = torch.eye(5, 8)[1]
        before = [r.doc_id for r in small_index.search(probe, top_k=3)]
        after = [r.doc_id for r in reloaded.search(probe, top_k=3)]
        assert before == after

    def test_loaded_get_vector(self, small_index, tmp_path):
        """Vector reconstruction still works after a reload."""
        target = str(tmp_path / "test_hnsw")
        small_index.save(target)
        reloaded = EngramIndex.load(target)

        vector = reloaded.get_vector("doc_2")
        assert vector is not None
        assert vector.shape == (8,)
tests/test_integration_synthetic.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Synthetic Integration Test
|
| 3 |
+
Full pipeline E2E with synthetic tensors — no real model needed.
|
| 4 |
+
|
| 5 |
+
Pipeline: create KV → extract state → serialize .eng → load → index → query → retrieve
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import pytest
|
| 13 |
+
import torch
|
| 14 |
+
from safetensors.torch import load_file
|
| 15 |
+
|
| 16 |
+
from kvcos.core.cache_spec import LLAMA_3_1_8B
|
| 17 |
+
from kvcos.core.serializer import EngramSerializer
|
| 18 |
+
from kvcos.core.types import CompressionMethod, StateExtractionMode
|
| 19 |
+
from kvcos.core.manifold_index import ManifoldIndex
|
| 20 |
+
from kvcos.core.retriever import EGRRetriever
|
| 21 |
+
from kvcos.core.state_extractor import MARStateExtractor
|
| 22 |
+
from kvcos.storage.local import LocalStorageBackend
|
| 23 |
+
from tests.conftest import make_synthetic_kv
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TestFullPipeline:
|
| 27 |
+
"""End-to-end: store → index → query → retrieve using synthetic data."""
|
| 28 |
+
|
| 29 |
+
def test_serialize_round_trip(self, tmp_data_dir: Path) -> None:
|
| 30 |
+
"""Step 1-4: Create → serialize → load → verify shape."""
|
| 31 |
+
keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
|
| 32 |
+
assert keys.shape == (32, 8, 256, 128)
|
| 33 |
+
|
| 34 |
+
serializer = EngramSerializer()
|
| 35 |
+
eng_path = tmp_data_dir / "roundtrip.eng"
|
| 36 |
+
|
| 37 |
+
serializer.serialize(
|
| 38 |
+
keys=keys, values=values,
|
| 39 |
+
agent_id="integration-test", task_description="round-trip test",
|
| 40 |
+
model_id=LLAMA_3_1_8B["model_id"], output_path=eng_path,
|
| 41 |
+
compression=CompressionMethod.FP16,
|
| 42 |
+
)
|
| 43 |
+
assert eng_path.exists()
|
| 44 |
+
|
| 45 |
+
# Verify valid safetensors
|
| 46 |
+
tensors = load_file(str(eng_path))
|
| 47 |
+
assert "layer_0_keys" in tensors
|
| 48 |
+
|
| 49 |
+
k_out, v_out, meta = serializer.deserialize(eng_path)
|
| 50 |
+
assert k_out.shape == keys.shape
|
| 51 |
+
assert v_out.shape == values.shape
|
| 52 |
+
|
| 53 |
+
@pytest.mark.parametrize("mode", list(StateExtractionMode))
|
| 54 |
+
def test_extraction_all_modes(self, mode: StateExtractionMode) -> None:
|
| 55 |
+
"""Step 2: Extract state vector in all 3 modes."""
|
| 56 |
+
keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
|
| 57 |
+
extractor = MARStateExtractor(mode=mode, rank=128)
|
| 58 |
+
result = extractor.extract(keys, LLAMA_3_1_8B)
|
| 59 |
+
|
| 60 |
+
assert result.state_vec.dim() == 1
|
| 61 |
+
assert result.state_vec.shape[0] > 0
|
| 62 |
+
assert result.l2_norm > 0
|
| 63 |
+
assert result.mode == mode
|
| 64 |
+
|
| 65 |
+
def test_index_and_query(self, tmp_data_dir: Path) -> None:
|
| 66 |
+
"""Step 5-6: Index state vector → query with different tensor → get result."""
|
| 67 |
+
keys_a, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=42)
|
| 68 |
+
keys_b, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=99)
|
| 69 |
+
|
| 70 |
+
extractor = MARStateExtractor(
|
| 71 |
+
mode=StateExtractionMode.MEAN_POOL,
|
| 72 |
+
)
|
| 73 |
+
dim = extractor.output_dim(LLAMA_3_1_8B)
|
| 74 |
+
index = ManifoldIndex(dim=dim)
|
| 75 |
+
|
| 76 |
+
# Extract and index first tensor
|
| 77 |
+
from kvcos.core.manifold_index import IndexEntry
|
| 78 |
+
|
| 79 |
+
result_a = extractor.extract(keys_a, LLAMA_3_1_8B)
|
| 80 |
+
index.add(
|
| 81 |
+
result_a.state_vec,
|
| 82 |
+
IndexEntry(
|
| 83 |
+
cache_id="test-cache-a",
|
| 84 |
+
task_description="indexed engram",
|
| 85 |
+
model_id=LLAMA_3_1_8B["model_id"],
|
| 86 |
+
created_at="2026-01-01T00:00:00Z",
|
| 87 |
+
context_len=256,
|
| 88 |
+
l2_norm=result_a.l2_norm,
|
| 89 |
+
),
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Query with second tensor
|
| 93 |
+
result_b = extractor.extract(keys_b, LLAMA_3_1_8B)
|
| 94 |
+
results = index.search(result_b.state_vec, top_k=1)
|
| 95 |
+
|
| 96 |
+
assert len(results) >= 1
|
| 97 |
+
assert results[0]["cache_id"] == "test-cache-a"
|
| 98 |
+
|
| 99 |
+
def test_full_egr_pipeline(self, tmp_data_dir: Path) -> None:
|
| 100 |
+
"""Step 7: Full EGR retrieval — store → index → query → retrieve."""
|
| 101 |
+
keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=42)
|
| 102 |
+
query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256, seed=99)
|
| 103 |
+
|
| 104 |
+
extractor = MARStateExtractor(
|
| 105 |
+
mode=StateExtractionMode.MEAN_POOL,
|
| 106 |
+
)
|
| 107 |
+
dim = extractor.output_dim(LLAMA_3_1_8B)
|
| 108 |
+
index = ManifoldIndex(dim=dim)
|
| 109 |
+
storage = LocalStorageBackend(data_dir=tmp_data_dir)
|
| 110 |
+
retriever = EGRRetriever(extractor, index, storage)
|
| 111 |
+
|
| 112 |
+
# Store
|
| 113 |
+
cache_id = retriever.index_engram(
|
| 114 |
+
keys=keys, values=values, spec=LLAMA_3_1_8B,
|
| 115 |
+
agent_id="integration-test",
|
| 116 |
+
task_description="full pipeline test",
|
| 117 |
+
model_id=LLAMA_3_1_8B["model_id"],
|
| 118 |
+
output_dir=tmp_data_dir,
|
| 119 |
+
)
|
| 120 |
+
assert isinstance(cache_id, str)
|
| 121 |
+
assert index.n_entries == 1
|
| 122 |
+
|
| 123 |
+
# Retrieve
|
| 124 |
+
response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=1)
|
| 125 |
+
assert len(response.results) >= 1
|
| 126 |
+
|
| 127 |
+
result = response.results[0]
|
| 128 |
+
assert result.cache_id == cache_id
|
| 129 |
+
assert result.keys.shape == keys.shape
|
| 130 |
+
assert result.values.shape == values.shape
|
| 131 |
+
assert result.similarity != 0.0
|
| 132 |
+
|
| 133 |
+
def test_multi_engram_ranking(self, tmp_data_dir: Path) -> None:
|
| 134 |
+
"""Store 3 engrams, query, verify results are ranked by similarity."""
|
| 135 |
+
extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
|
| 136 |
+
dim = extractor.output_dim(LLAMA_3_1_8B)
|
| 137 |
+
index = ManifoldIndex(dim=dim)
|
| 138 |
+
storage = LocalStorageBackend(data_dir=tmp_data_dir)
|
| 139 |
+
retriever = EGRRetriever(extractor, index, storage)
|
| 140 |
+
|
| 141 |
+
for seed in (10, 20, 30):
|
| 142 |
+
keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=seed)
|
| 143 |
+
retriever.index_engram(
|
| 144 |
+
keys=keys, values=values, spec=LLAMA_3_1_8B,
|
| 145 |
+
agent_id="test", task_description=f"seed-{seed}",
|
| 146 |
+
model_id=LLAMA_3_1_8B["model_id"],
|
| 147 |
+
output_dir=tmp_data_dir,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
assert index.n_entries == 3
|
| 151 |
+
|
| 152 |
+
query_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=10)
|
| 153 |
+
response = retriever.retrieve(query_keys, LLAMA_3_1_8B, top_k=3)
|
| 154 |
+
|
| 155 |
+
assert len(response.results) == 3
|
| 156 |
+
# Results should be sorted by descending similarity
|
| 157 |
+
sims = [r.similarity for r in response.results]
|
| 158 |
+
assert sims == sorted(sims, reverse=True)
|
| 159 |
+
# Closest match should be seed=10 (same as query)
|
| 160 |
+
assert response.results[0].task_description == "seed-10"
|
tests/test_iswa_blob_parser.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ISWA Blob Parser Tests
|
| 3 |
+
Tests for multi-section KV cache parsing (Gemma 4 ISWA format).
|
| 4 |
+
|
| 5 |
+
Uses synthetic ISWA blobs from conftest.make_synthetic_iswa_blob().
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from kvcos.core.blob_parser import (
|
| 14 |
+
BlobParseError,
|
| 15 |
+
ParsedKVCache,
|
| 16 |
+
ParsedMultiSectionCache,
|
| 17 |
+
parse_multi_section_blob,
|
| 18 |
+
parse_state_blob,
|
| 19 |
+
)
|
| 20 |
+
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
|
| 21 |
+
from tests.conftest import (
|
| 22 |
+
GEMMA4_GLOBAL_SECTION,
|
| 23 |
+
GEMMA4_SECTIONS,
|
| 24 |
+
GEMMA4_SWA_SECTION,
|
| 25 |
+
make_synthetic_iswa_blob,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TestParseMultiSectionBlob:
    """Parse ISWA blobs with multiple KV cache sections."""

    @staticmethod
    def _parse(n_cells: int, **blob_kwargs) -> "ParsedMultiSectionCache":
        # Build a synthetic Gemma-4 blob and parse it in one step.
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=n_cells, **blob_kwargs)
        return parse_multi_section_blob(blob, GEMMA4_SECTIONS)

    def test_parse_gemma4_shape(self) -> None:
        parsed = self._parse(n_cells=4)

        assert len(parsed.sections) == 2
        global_sec, swa_sec = parsed.sections

        # Global section: [5, 2, 4, 512]
        assert global_sec.keys.shape == (5, 2, 4, 512)
        assert global_sec.values.shape == (5, 2, 4, 512)

        # SWA section: [25, 8, 4, 256]
        assert swa_sec.keys.shape == (25, 8, 4, 256)
        assert swa_sec.values.shape == (25, 8, 4, 256)

    def test_parse_metadata(self) -> None:
        parsed = self._parse(n_cells=4)

        assert parsed.arch == "gemma4"
        assert parsed.n_sections == 2
        assert parsed.total_layers == 30

        assert parsed.sections[0].n_layers == 5
        assert parsed.sections[0].arch == "gemma4"
        assert parsed.sections[1].n_layers == 25

    def test_parse_cells(self) -> None:
        parsed = self._parse(n_cells=4)

        # Every section carries the same cell count, with sequential positions.
        for section in parsed.sections:
            assert section.n_cells == 4
            assert len(section.cells) == 4
            assert section.cells[0].pos == 0
            assert section.cells[3].pos == 3

    def test_dtype_float16(self) -> None:
        parsed = self._parse(n_cells=2)

        for section in parsed.sections:
            assert section.keys.dtype == torch.float16
            assert section.values.dtype == torch.float16

    def test_different_cell_counts(self) -> None:
        parsed = self._parse(n_cells=8)

        assert parsed.sections[0].n_cells == 8
        assert parsed.sections[1].n_cells == 8

    def test_non_transposed_v(self) -> None:
        parsed = self._parse(n_cells=2, v_trans=False)

        for section in parsed.sections:
            assert section.v_trans is False

    def test_single_section_works(self) -> None:
        """Single-section ISWA parse should work identically to standard."""
        only_global = (GEMMA4_GLOBAL_SECTION,)
        blob = make_synthetic_iswa_blob(only_global, n_cells=4)
        parsed = parse_multi_section_blob(blob, only_global)

        assert len(parsed.sections) == 1
        assert parsed.sections[0].keys.shape == (5, 2, 4, 512)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class TestParseMultiSectionErrors:
    """Error handling for ISWA blob parsing."""

    def test_section_mismatch_raises(self) -> None:
        """Blob has 2 sections but we pass specs for 3."""
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        too_many_specs = GEMMA4_SECTIONS + (GEMMA4_GLOBAL_SECTION,)
        with pytest.raises(BlobParseError, match="Expected 3.*got 2"):
            parse_multi_section_blob(blob, too_many_specs)

    def test_truncated_blob_raises(self) -> None:
        # Cutting the blob short must fail loudly, not produce garbage tensors.
        whole_blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        with pytest.raises(BlobParseError):
            parse_multi_section_blob(whole_blob[:100], GEMMA4_SECTIONS)

    def test_wrong_dimensions_raises(self) -> None:
        """Pass wrong KV head count for a section."""
        bad_specs = (
            CacheSection(AttentionType.FULL, 5, 4, 512),  # wrong: 4 heads not 2
            GEMMA4_SWA_SECTION,
        )
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4)
        with pytest.raises(BlobParseError):
            parse_multi_section_blob(blob, bad_specs)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TestStandardBlobBackwardCompat:
    """Ensure parse_state_blob still works for single-stream blobs."""

    def test_single_stream_still_works(self) -> None:
        # Reuse the single-stream blob builder from the standard parser tests.
        from tests.test_blob_parser import _make_blob

        single_stream = _make_blob(16, 32, 8, 128)
        parsed = parse_state_blob(single_stream, n_kv_heads=8, head_dim=128)
        assert parsed.keys.shape == (32, 8, 16, 128)
|
tests/test_iswa_bridge.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ISWA Bridge Tests
|
| 3 |
+
Tests for multi-architecture metadata detection and ISWA cache extraction.
|
| 4 |
+
Does NOT require a real GGUF model — tests the metadata helpers and spec logic.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
from integrations.llama_cpp_bridge import _meta_get
|
| 12 |
+
from kvcos.core.cache_spec import (
|
| 13 |
+
GEMMA_4_26B_A4B,
|
| 14 |
+
LLAMA_3_1_8B,
|
| 15 |
+
get_model_spec,
|
| 16 |
+
is_iswa_spec,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestMetaGet:
    """Metadata key fallback chain across architecture prefixes."""

    def test_llama_prefix(self) -> None:
        assert _meta_get({"llama.block_count": "32"}, "block_count") == "32"

    def test_gemma4_prefix(self) -> None:
        assert _meta_get({"gemma4.block_count": "30"}, "block_count") == "30"

    def test_gemma_prefix(self) -> None:
        metadata = {"gemma.attention.head_count": "8"}
        assert _meta_get(metadata, "attention.head_count") == "8"

    def test_general_fallback(self) -> None:
        assert _meta_get({"general.block_count": "28"}, "block_count") == "28"

    def test_default_when_missing(self) -> None:
        # No prefixed key present → the caller-supplied default wins.
        assert _meta_get({}, "block_count", "32") == "32"

    def test_llama_takes_priority(self) -> None:
        # llama.* beats gemma4.* and general.* when all are present.
        metadata = {
            "llama.block_count": "32",
            "gemma4.block_count": "30",
            "general.block_count": "28",
        }
        assert _meta_get(metadata, "block_count") == "32"

    def test_gemma4_before_general(self) -> None:
        metadata = {
            "gemma4.embedding_length": "3072",
            "general.embedding_length": "4096",
        }
        assert _meta_get(metadata, "embedding_length") == "3072"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TestISWASpecDetection:
    """Registry and ISWA detection."""

    def test_gemma4_in_registry(self) -> None:
        registered = get_model_spec("google/gemma-4-26b-a4b-it")
        assert registered is not None
        assert registered["model_family"] == "gemma"

    def test_gemma4_is_iswa(self) -> None:
        assert is_iswa_spec(GEMMA_4_26B_A4B) is True

    def test_llama_not_iswa(self) -> None:
        assert is_iswa_spec(LLAMA_3_1_8B) is False

    def test_gemma4_sections_correct(self) -> None:
        sections = GEMMA_4_26B_A4B["cache_sections"]
        assert len(sections) == 2
        global_sec, swa_sec = sections

        # Global section
        assert global_sec.n_layers == 5
        assert global_sec.n_kv_heads == 2
        assert global_sec.head_dim == 512
        assert global_sec.n_embd_kv == 1024

        # SWA section
        assert swa_sec.n_layers == 25
        assert swa_sec.n_kv_heads == 8
        assert swa_sec.head_dim == 256
        assert swa_sec.window_size == 1024

    def test_gemma4_total_layers(self) -> None:
        # Section layer counts must account for every layer in the model.
        layer_total = sum(sec.n_layers for sec in GEMMA_4_26B_A4B["cache_sections"])
        assert layer_total == GEMMA_4_26B_A4B["n_layers"]
|
tests/test_iswa_fingerprint.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ISWA Fingerprint Tests
|
| 3 |
+
Tests for per-section Fourier fingerprint computation and concatenation.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from kvcos.core.blob_parser import ParsedKVCache, ParsedMultiSectionCache, parse_multi_section_blob
|
| 11 |
+
from kvcos.core.fingerprint import (
|
| 12 |
+
compute_fourier_fingerprint_v2,
|
| 13 |
+
compute_iswa_fingerprint,
|
| 14 |
+
)
|
| 15 |
+
from kvcos.core.types import AttentionType, CacheSection
|
| 16 |
+
from tests.conftest import GEMMA4_SECTIONS, make_synthetic_iswa_blob
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestISWAFingerprint:
    """Per-section fingerprint computation for ISWA models."""

    def _make_parsed(self, n_cells: int = 4) -> ParsedMultiSectionCache:
        # Synthesize and parse a two-section Gemma-4 blob.
        blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=n_cells)
        return parse_multi_section_blob(blob, GEMMA4_SECTIONS)

    def test_fingerprint_shape(self) -> None:
        fingerprint = compute_iswa_fingerprint(self._make_parsed(), freqs=[0, 1])

        # Global: 2 * 512 * 2 = 2048
        # SWA: 8 * 256 * 2 = 4096
        # Total: 6144
        assert fingerprint.shape == (6144,)

    def test_fingerprint_dtype(self) -> None:
        fingerprint = compute_iswa_fingerprint(self._make_parsed())
        assert fingerprint.dtype == torch.float32

    def test_fingerprint_normalized(self) -> None:
        """Each section's sub-FP is concat of per-freq L2-normalized vectors."""
        import math

        fingerprint = compute_iswa_fingerprint(self._make_parsed(), freqs=[0, 1])

        # Global section FP: first 2048 dims (1024 per freq, 2 freqs);
        # SWA section FP: remaining 4096 dims (2048 per freq).
        global_part, swa_part = fingerprint[:2048], fingerprint[2048:]

        # Each sub-section is 2 concatenated unit vectors → norm = sqrt(2).
        target_norm = math.sqrt(2)
        assert abs(global_part.norm().item() - target_norm) < 0.05
        assert abs(swa_part.norm().item() - target_norm) < 0.05

    def test_deterministic(self) -> None:
        parsed = self._make_parsed()
        first_fp = compute_iswa_fingerprint(parsed)
        second_fp = compute_iswa_fingerprint(parsed)
        assert torch.allclose(first_fp, second_fp)

    def test_different_inputs_differ(self) -> None:
        parsed_a = self._make_parsed(n_cells=4)
        other_blob = make_synthetic_iswa_blob(GEMMA4_SECTIONS, n_cells=4, seed=999)
        parsed_b = parse_multi_section_blob(other_blob, GEMMA4_SECTIONS)

        fp_a = compute_iswa_fingerprint(parsed_a)
        fp_b = compute_iswa_fingerprint(parsed_b)
        similarity = torch.nn.functional.cosine_similarity(
            fp_a.unsqueeze(0), fp_b.unsqueeze(0)
        )
        assert similarity.item() < 0.99  # different inputs → different FPs

    def test_single_section_matches_standard(self) -> None:
        """Single-section ISWA FP should match standard FP."""
        lone_section = CacheSection(AttentionType.FULL, 5, 2, 512)
        blob = make_synthetic_iswa_blob((lone_section,), n_cells=4)
        parsed = parse_multi_section_blob(blob, (lone_section,))

        iswa_fp = compute_iswa_fingerprint(parsed, freqs=[0, 1])

        # Compare with standard FP on same data
        layer_keys = parsed.sections[0].keys.float().mean(dim=2)
        reference_fp = compute_fourier_fingerprint_v2(layer_keys, freqs=[0, 1])

        assert torch.allclose(iswa_fp, reference_fp, atol=1e-5)

    def test_custom_freqs(self) -> None:
        parsed = self._make_parsed()

        # f0 only: Global(1024) + SWA(2048) = 3072
        assert compute_iswa_fingerprint(parsed, freqs=[0]).shape == (3072,)
        # f0+f1: double
        assert compute_iswa_fingerprint(parsed, freqs=[0, 1]).shape == (6144,)
|
tests/test_iswa_types.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ISWA Type System Tests
|
| 3 |
+
Tests for CacheSection, AttentionType, and ISWA-aware ModelCacheSpec.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestAttentionType:
    """AttentionType enum members behave as plain strings."""

    def test_full_value(self) -> None:
        full = AttentionType.FULL
        assert full == "full"

    def test_sliding_value(self) -> None:
        sliding = AttentionType.SLIDING
        assert sliding == "sliding"

    def test_is_str(self) -> None:
        # str-backed enum: members pass isinstance(str) checks.
        assert isinstance(AttentionType.FULL, str)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TestCacheSection:
    """CacheSection frozen dataclass."""

    @staticmethod
    def _global_section() -> CacheSection:
        # Canonical global-attention section reused by several tests below.
        return CacheSection(
            attention_type=AttentionType.FULL,
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        )

    def test_global_section(self) -> None:
        section = self._global_section()
        assert section.n_layers == 5
        assert section.n_kv_heads == 2
        assert section.head_dim == 512
        assert section.window_size is None

    def test_sliding_section(self) -> None:
        section = CacheSection(
            attention_type=AttentionType.SLIDING,
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,
        )
        assert section.attention_type == AttentionType.SLIDING
        assert section.window_size == 1024

    def test_n_embd_kv(self) -> None:
        # Derived KV embedding width: n_kv_heads * head_dim.
        assert self._global_section().n_embd_kv == 1024  # 2 * 512

    def test_frozen(self) -> None:
        section = self._global_section()
        # Frozen dataclass: attribute assignment must raise.
        with pytest.raises(AttributeError):
            section.n_layers = 10  # type: ignore[misc]

    def test_equality(self) -> None:
        first = CacheSection(AttentionType.FULL, 5, 2, 512)
        second = CacheSection(AttentionType.FULL, 5, 2, 512)
        assert first == second

    def test_inequality(self) -> None:
        full_sec = CacheSection(AttentionType.FULL, 5, 2, 512)
        sliding_sec = CacheSection(AttentionType.SLIDING, 25, 8, 256, 1024)
        assert full_sec != sliding_sec
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestModelCacheSpecISWA:
    """ModelCacheSpec with optional cache_sections."""

    # Gemma-4 section layout shared by the ISWA tests below.
    _GEMMA4_SECTIONS = (
        CacheSection(AttentionType.FULL, 5, 2, 512),
        CacheSection(AttentionType.SLIDING, 25, 8, 256, 1024),
    )

    def test_standard_spec_no_sections(self) -> None:
        llama_spec = ModelCacheSpec(
            model_id="meta-llama/Llama-3.1-8B-Instruct",
            model_family="llama",
            n_layers=32,
            n_heads=32,
            n_kv_heads=8,
            head_dim=128,
            rope_enabled=True,
            extraction_layers=tuple(range(8, 32)),
        )
        # Standard (non-ISWA) specs carry no cache_sections key.
        assert "cache_sections" not in llama_spec
        assert llama_spec["n_kv_heads"] == 8

    def test_iswa_spec_with_sections(self) -> None:
        gemma_spec = ModelCacheSpec(
            model_id="google/gemma-4-26b-a4b-it",
            model_family="gemma",
            n_layers=30,
            n_heads=32,
            n_kv_heads=8,
            head_dim=256,
            rope_enabled=True,
            extraction_layers=tuple(range(8, 30)),
            cache_sections=self._GEMMA4_SECTIONS,
        )
        assert "cache_sections" in gemma_spec
        assert len(gemma_spec["cache_sections"]) == 2
        assert gemma_spec["cache_sections"][0].n_embd_kv == 1024
        assert gemma_spec["cache_sections"][1].n_embd_kv == 2048

    def test_iswa_total_layers_match(self) -> None:
        layer_total = sum(sec.n_layers for sec in self._GEMMA4_SECTIONS)
        assert layer_total == 30

    def test_backward_compat_existing_specs(self) -> None:
        """Existing specs without cache_sections still work."""
        from kvcos.core.cache_spec import LLAMA_3_1_8B
        assert LLAMA_3_1_8B["n_kv_heads"] == 8
        assert "cache_sections" not in LLAMA_3_1_8B
|
tests/test_knowledge_index.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for kvcos.engram.knowledge_index — HNSW knowledge search."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from kvcos.engram.embedder import get_fingerprint
|
| 10 |
+
from kvcos.engram.format import EigramEncoder
|
| 11 |
+
from kvcos.engram.knowledge_index import KnowledgeIndex
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
def knowledge_dir(tmp_path):
    """Create a temporary knowledge directory with test .eng files."""
    encoder = EigramEncoder()
    project_dir = tmp_path / "test_project"
    project_dir.mkdir()

    corpus = [
        ("doc_ml", "Machine learning model training and optimization"),
        ("doc_db", "PostgreSQL database schema migration tools"),
        ("doc_api", "REST API endpoint authentication and authorization"),
        ("doc_test", "Unit testing with pytest fixtures and mocking"),
        ("doc_deploy", "Docker container deployment to Kubernetes cluster"),
    ]

    for doc_id, text in corpus:
        fingerprint, fp_source = get_fingerprint(text)
        fp_dim = fingerprint.shape[0]

        # Encode a minimal .eng payload — only the fingerprint fields matter
        # for index tests, so the basis vectors are zero-filled.
        payload = encoder.encode(
            vec_perdoc=torch.zeros(116),
            vec_fcdb=torch.zeros(116),
            joint_center=torch.zeros(128),
            corpus_hash="test" * 8,
            model_id=fp_source[:16],
            basis_rank=116,
            n_corpus=0,
            layer_range=(0, 0),
            context_len=len(text),
            l2_norm=float(torch.norm(fingerprint).item()),
            scs=0.0,
            margin_proof=0.0,
            task_description=text[:256],
            cache_id=doc_id,
            vec_fourier=fingerprint if fp_dim == 2048 else None,
            vec_fourier_v2=fingerprint,
            confusion_flag=False,
        )

        eng_path = project_dir / f"{doc_id}.eng"
        eng_path.write_bytes(payload)

        # Sidecar JSON metadata alongside each .eng file.
        sidecar = {
            "cache_id": doc_id,
            "task_description": text,
            "source_path": f"/test/{doc_id}.md",
            "project": "test_project",
            "fp_source": fp_source,
            "chunk_index": 0,
            "chunk_total": 1,
            "headers": [],
        }
        Path(str(eng_path) + ".meta.json").write_text(json.dumps(sidecar))

    return tmp_path
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class TestKnowledgeIndexBuild:
    def test_build_from_directory(self, knowledge_dir):
        built = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        # The fixture writes exactly five .eng documents.
        assert len(built) == 5

    def test_build_empty_directory(self, tmp_path):
        # A directory with no .eng files cannot produce an index.
        with pytest.raises(ValueError, match="No .eng files"):
            KnowledgeIndex.build_from_knowledge_dir(tmp_path, verbose=False)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class TestKnowledgeIndexSearch:
    @staticmethod
    def _build(knowledge_dir):
        # All search tests start from a freshly built index.
        return KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)

    def test_search_returns_results(self, knowledge_dir):
        hits = self._build(knowledge_dir).search("database query optimization", k=3)
        assert len(hits) == 3
        assert all(hit.score > 0 for hit in hits)

    def test_search_result_fields(self, knowledge_dir):
        top_hit = self._build(knowledge_dir).search("testing", k=1)[0]
        assert top_hit.doc_id
        assert isinstance(top_hit.score, float)
        assert top_hit.rank == 0
        assert top_hit.project == "test_project"

    def test_search_with_tensor(self, knowledge_dir):
        # A precomputed fingerprint tensor is also accepted as a query.
        query_fp, _ = get_fingerprint("unit tests")
        hits = self._build(knowledge_dir).search(query_fp, k=2)
        assert len(hits) == 2

    def test_search_margin(self, knowledge_dir):
        hits = self._build(knowledge_dir).search("testing", k=3)
        # Top result should have a margin
        assert hits[0].margin >= 0
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class TestKnowledgeIndexPersistence:
    def test_save_and_load(self, knowledge_dir, tmp_path):
        original = KnowledgeIndex.build_from_knowledge_dir(knowledge_dir, verbose=False)
        index_dir = tmp_path / "index"
        original.save(index_dir)

        restored = KnowledgeIndex.load(index_dir)
        assert len(restored) == len(original)

        # Search should work on loaded index
        hits = restored.search("database", k=2)
        assert len(hits) == 2

    def test_load_nonexistent(self, tmp_path):
        # Loading a missing index directory must fail explicitly.
        with pytest.raises(FileNotFoundError):
            KnowledgeIndex.load(tmp_path / "nonexistent")
|
tests/test_manifest.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for kvcos.engram.manifest — knowledge index registry."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from kvcos.engram.manifest import ChunkRecord, Manifest, SourceRecord, _content_hash
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@pytest.fixture
def tmp_manifest(tmp_path):
    """Create a Manifest with a temporary path."""
    manifest_path = tmp_path / "manifest.json"
    # Loading a nonexistent path yields an empty manifest bound to that path.
    return Manifest.load(manifest_path)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestContentHash:
    def test_deterministic(self):
        # Same input must always map to the same digest.
        first = _content_hash("hello")
        second = _content_hash("hello")
        assert first == second

    def test_different_content(self):
        # Distinct inputs must produce distinct digests.
        assert _content_hash("hello") != _content_hash("world")
|
| 26 |
+
class TestManifestLoad:
    def test_load_nonexistent_creates_empty(self, tmp_path):
        manifest = Manifest.load(tmp_path / "does_not_exist.json")
        assert manifest.total_sources == 0
        assert manifest.total_chunks == 0

    def test_load_existing(self, tmp_path):
        # Write a manifest, then load it
        manifest_path = tmp_path / "manifest.json"
        single_chunk = [ChunkRecord(
            eng_path="/test/file.eng",
            chunk_index=0,
            chunk_total=1,
            char_start=0,
            char_end=100,
            indexed_at=1000.0,
        )]
        manifest = Manifest.load(manifest_path)
        manifest = manifest.register(
            source_path="/test/file.md",
            content_hash="abc123",
            project="test",
            file_size=100,
            chunks=single_chunk,
        )

        # Load again from disk
        reloaded = Manifest.load(manifest_path)
        assert reloaded.total_sources == 1
        assert reloaded.total_chunks == 1
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TestManifestRegister:
    """Registering sources and their chunk records."""

    def test_register_new(self, tmp_manifest):
        record = ChunkRecord(
            eng_path="/out/test.eng",
            chunk_index=0,
            chunk_total=1,
            char_start=0,
            char_end=50,
            indexed_at=1000.0,
        )
        updated = tmp_manifest.register(
            source_path="/src/test.md",
            content_hash="hash1",
            project="myproject",
            file_size=50,
            chunks=[record],
        )
        assert updated.total_sources == 1
        assert updated.total_chunks == 1
        assert "myproject" in updated.projects

    def test_register_overwrites_existing(self, tmp_manifest):
        first_chunks = [ChunkRecord(
            eng_path="/out/v1.eng", chunk_index=0, chunk_total=1,
            char_start=0, char_end=50, indexed_at=1000.0,
        )]
        manifest = tmp_manifest.register(
            "/src/test.md", "hash1", "proj", 50, first_chunks,
        )
        assert manifest.total_chunks == 1

        # Re-registering the same path replaces its chunk set wholesale.
        second_chunks = [
            ChunkRecord("/out/v2_1.eng", 0, 2, 0, 25, 2000.0),
            ChunkRecord("/out/v2_2.eng", 1, 2, 25, 50, 2000.0),
        ]
        manifest = manifest.register("/src/test.md", "hash2", "proj", 50, second_chunks)
        assert manifest.total_sources == 1  # same source, not duplicated
        assert manifest.total_chunks == 2  # chunk set replaced

    def test_register_returns_new_manifest(self, tmp_manifest):
        """Register returns a new Manifest (immutability)."""
        before = tmp_manifest
        after = before.register("/src/a.md", "h", "p", 10, [])
        assert before.total_sources == 0  # original untouched
        assert after.total_sources == 1
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class TestManifestNeedsReindex:
    """Change detection via content hashes."""

    def test_unknown_file_needs_index(self, tmp_manifest):
        # Anything not yet registered must be indexed.
        assert tmp_manifest.needs_reindex("/new/file.md", "any_hash")

    def test_same_hash_no_reindex(self, tmp_manifest):
        manifest = tmp_manifest.register("/src/a.md", "hash1", "p", 10, [])
        # Unchanged content hash means no work to do.
        assert not manifest.needs_reindex("/src/a.md", "hash1")

    def test_different_hash_needs_reindex(self, tmp_manifest):
        manifest = tmp_manifest.register("/src/a.md", "hash1", "p", 10, [])
        # A changed hash flags the file for re-indexing.
        assert manifest.needs_reindex("/src/a.md", "hash2")
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class TestManifestUnregister:
    """Removing sources from the registry."""

    def test_unregister_existing(self, tmp_manifest):
        manifest = tmp_manifest.register("/src/a.md", "h", "p", 10, [])
        manifest = manifest.unregister("/src/a.md")
        assert manifest.total_sources == 0

    def test_unregister_nonexistent(self, tmp_manifest):
        # Unregistering an unknown path is a no-op, not an error.
        manifest = tmp_manifest.unregister("/not/here.md")
        assert manifest.total_sources == 0
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class TestManifestQueries:
    """Read-only query helpers and container protocol on the manifest."""

    def test_get_project_records(self, tmp_manifest):
        manifest = (
            tmp_manifest
            .register("/a.md", "h1", "proj_a", 10, [])
            .register("/b.md", "h2", "proj_b", 20, [])
            .register("/c.md", "h3", "proj_a", 30, [])
        )
        records_a = manifest.get_project_records("proj_a")
        assert len(records_a) == 2

    def test_summary(self, tmp_manifest):
        chunk = ChunkRecord("/a.eng", 0, 1, 0, 10, 1000.0)
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [chunk])
        report = manifest.summary()
        assert report["total_sources"] == 1
        assert report["total_chunks"] == 1
        assert "p" in report["projects"]

    def test_contains(self, tmp_manifest):
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [])
        assert "/a.md" in manifest
        assert "/b.md" not in manifest

    def test_len(self, tmp_manifest):
        manifest = tmp_manifest.register("/a.md", "h", "p", 10, [])
        assert len(manifest) == 1
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class TestManifestPersistence:
    """On-disk persistence after mutation."""

    def test_atomic_write(self, tmp_path):
        target = tmp_path / "manifest.json"
        manifest = Manifest.load(target)
        manifest = manifest.register("/a.md", "h", "p", 10, [])

        # register() must leave a readable file behind ...
        assert target.exists()

        # ... whose contents are valid, versioned JSON.
        data = json.loads(target.read_text())
        assert data["version"] == 1
        assert len(data["sources"]) == 1
|
tests/test_manifold_index.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Manifold Index Tests
|
| 3 |
+
Tests for FAISS IndexFlatIP add/search/remove/persist (D2, D4).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pytest
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _entry(cid: str = "c1", model: str = "llama") -> IndexEntry:
    """Build a minimal IndexEntry with overridable id and model."""
    fields = {
        "cache_id": cid,
        "task_description": "test",
        "model_id": model,
        "created_at": "2026-01-01T00:00:00Z",
        "context_len": 256,
        "l2_norm": 1.0,
    }
    return IndexEntry(**fields)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TestAddAndSearch:
    """Add vectors, search via MIPS."""

    def test_add_increments(self) -> None:
        index = ManifoldIndex(dim=8)
        for cid in ("a", "b"):
            index.add(torch.randn(8), _entry(cid))
        assert index.n_entries == 2

    def test_search_returns_correct_order(self) -> None:
        index = ManifoldIndex(dim=4)
        aligned = torch.tensor([1.0, 0.0, 0.0, 0.0])
        orthogonal = torch.tensor([0.0, 1.0, 0.0, 0.0])
        index.add(aligned, _entry("close"))
        index.add(orthogonal, _entry("far"))

        hits = index.search(torch.tensor([1.0, 0.0, 0.0, 0.0]), top_k=2)
        # The vector aligned with the query wins on inner product.
        assert hits[0]["cache_id"] == "close"
        assert hits[0]["similarity"] > hits[1]["similarity"]

    def test_search_empty_returns_empty(self) -> None:
        index = ManifoldIndex(dim=4)
        assert index.search(torch.randn(4), top_k=5) == []

    def test_model_filter(self) -> None:
        index = ManifoldIndex(dim=4)
        index.add(torch.randn(4), _entry("a", model="llama"))
        index.add(torch.randn(4), _entry("b", model="phi"))
        # Filtering by model_id must exclude all other models' entries.
        hits = index.search(torch.randn(4), top_k=10, model_id="phi")
        assert all(hit["model_id"] == "phi" for hit in hits)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TestRemoveAndRebuild:
    """Remove entries and rebuild index."""

    def test_remove_hides_from_search(self) -> None:
        index = ManifoldIndex(dim=4)
        vec = torch.tensor([1.0, 0.0, 0.0, 0.0])
        index.add(vec, _entry("target"))
        assert index.remove("target")
        # Removed entries must no longer surface in search results.
        assert len(index.search(vec, top_k=1)) == 0

    def test_rebuild_compacts(self) -> None:
        index = ManifoldIndex(dim=4)
        for i in range(5):
            index.add(torch.randn(4), _entry(f"c{i}"))
        for victim in ("c1", "c3"):
            index.remove(victim)
        # rebuild() reports how many entries survive compaction.
        assert index.rebuild() == 3
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class TestPersistence:
    """Save/load round-trip (D2: serialize_index/deserialize_index)."""

    def test_save_load_round_trip(self, tmp_index_dir: Path) -> None:
        path = tmp_index_dir / "test.faiss"
        vec = torch.tensor([1.0, 0.0, 0.0, 0.0])

        original = ManifoldIndex(dim=4)
        original.add(vec, _entry("persisted"))
        original.save(path)

        # Constructing with index_path loads the saved state.
        restored = ManifoldIndex(dim=4, index_path=path)
        assert restored.n_entries == 1
        hits = restored.search(vec, top_k=1)
        assert hits[0]["cache_id"] == "persisted"

    def test_dim_mismatch_raises(self) -> None:
        index = ManifoldIndex(dim=4)
        # Adding an 8-dim vector to a 4-dim index must be rejected.
        with pytest.raises(ValueError, match="dim"):
            index.add(torch.randn(8), _entry("wrong"))
|
tests/test_retriever.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Retriever Tests
|
| 3 |
+
Tests for EGRRetriever: store → index → query → retrieve pipeline.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from kvcos.core.cache_spec import LLAMA_3_1_8B
|
| 13 |
+
from kvcos.core.serializer import EngramSerializer
|
| 14 |
+
from kvcos.core.types import CompressionMethod, StateExtractionMode
|
| 15 |
+
from kvcos.core.manifold_index import ManifoldIndex
|
| 16 |
+
from kvcos.core.retriever import EGRRetriever, RetrievalResponse
|
| 17 |
+
from kvcos.core.state_extractor import MARStateExtractor
|
| 18 |
+
from kvcos.storage.local import LocalStorageBackend
|
| 19 |
+
from tests.conftest import make_synthetic_kv
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _build_retriever(
    data_dir: Path, mode: StateExtractionMode = StateExtractionMode.MEAN_POOL,
) -> EGRRetriever:
    """Wire an EGRRetriever over local storage for tests.

    The index dimension is taken from the extractor so the two always agree.
    """
    extractor = MARStateExtractor(mode=mode, rank=128)
    index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
    backend = LocalStorageBackend(data_dir=data_dir)
    return EGRRetriever(extractor, index, backend)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TestIndexAndRetrieve:
    """Full store → search → load pipeline."""

    def test_index_returns_cache_id(self, tmp_data_dir: Path) -> None:
        retriever = _build_retriever(tmp_data_dir)
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        cache_id = retriever.index_engram(
            keys=keys, values=values, spec=LLAMA_3_1_8B,
            agent_id="test", task_description="test engram",
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=tmp_data_dir,
        )
        assert isinstance(cache_id, str)
        assert len(cache_id) > 0

    def test_retrieve_finds_stored(self, tmp_data_dir: Path) -> None:
        retriever = _build_retriever(tmp_data_dir)
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        retriever.index_engram(
            keys=keys, values=values, spec=LLAMA_3_1_8B,
            agent_id="test", task_description="findable engram",
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=tmp_data_dir,
        )

        # Query with differently seeded keys; the single stored engram
        # should still come back as the top hit.
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=99)
        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=1)

        assert isinstance(response, RetrievalResponse)
        assert len(response.results) == 1
        assert response.results[0].keys.shape == keys.shape

    def test_retrieve_empty_index(self, tmp_data_dir: Path) -> None:
        retriever = _build_retriever(tmp_data_dir)
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)
        assert len(response.results) == 0

    def test_delete_removes(self, tmp_data_dir: Path) -> None:
        retriever = _build_retriever(tmp_data_dir)
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        cache_id = retriever.index_engram(
            keys=keys, values=values, spec=LLAMA_3_1_8B,
            agent_id="test", task_description="deletable",
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=tmp_data_dir,
        )
        assert retriever.delete_engram(cache_id)

        # After deletion nothing should be retrievable.
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)
        assert len(response.results) == 0
|
tests/test_serializer.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Serializer Tests
|
| 3 |
+
Tests for .eng safetensors serialize/deserialize round-trip (D7).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import pytest
|
| 11 |
+
import torch
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
|
| 14 |
+
from kvcos.core.serializer import EngramSerializer, SerializationError
|
| 15 |
+
from kvcos.core.types import CompressionMethod
|
| 16 |
+
from tests.conftest import make_synthetic_kv
|
| 17 |
+
from kvcos.core.cache_spec import LLAMA_3_1_8B
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestSerializeRoundTrip:
    """Serialize → deserialize preserves shape, dtype, metadata."""

    def test_round_trip_shape(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=256)
        out_path = tmp_data_dir / "test.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="test-agent", task_description="unit test",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.FP16,
        )
        loaded_keys, loaded_values, _ = serializer.deserialize(out_path)

        assert loaded_keys.shape == keys.shape
        assert loaded_values.shape == values.shape

    def test_metadata_fields(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        out_path = tmp_data_dir / "meta.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="agent-42", task_description="metadata check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.Q8_0,
        )
        _, _, meta = serializer.deserialize(out_path)

        assert meta["agent_id"] == "agent-42"
        assert meta["task_description"] == "metadata check"
        assert meta["compression"] == "q8_0"
        # Note: n_layers comes back as a string, not an int.
        assert meta["n_layers"] == "32"
        assert meta["model_family"] == "llama"

    def test_safetensors_loadable(self, tmp_data_dir: Path) -> None:
        """D7: File must be valid safetensors."""
        serializer = EngramSerializer()
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        out_path = tmp_data_dir / "valid.eng"

        serializer.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="safetensors check",
            model_id=LLAMA_3_1_8B["model_id"], output_path=out_path,
            compression=CompressionMethod.FP16,
        )
        # Read back with the plain safetensors loader, bypassing our code.
        tensors = load_file(str(out_path))
        assert "layer_0_keys" in tensors
        assert "layer_0_values" in tensors

    def test_result_dict(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        result = serializer.serialize(
            keys=keys, values=values,
            agent_id="test", task_description="result check",
            model_id=LLAMA_3_1_8B["model_id"],
            output_path=tmp_data_dir / "result.eng",
        )
        assert "cache_id" in result
        assert result["size_bytes"] > 0
        assert result["n_layers"] == 32
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class TestSerializerErrors:
    """Edge cases and error handling."""

    def test_shape_mismatch_raises(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        # Same layers/heads/head-dim, but context lengths differ (64 vs 32).
        keys = torch.randn(32, 8, 64, 128, dtype=torch.float16)
        values = torch.randn(32, 8, 32, 128, dtype=torch.float16)

        with pytest.raises(SerializationError, match="mismatch"):
            serializer.serialize(
                keys=keys, values=values,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_3d_tensor_raises(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        # Missing the leading layer dimension — must be rejected as non-4D.
        malformed = torch.randn(8, 64, 128, dtype=torch.float16)

        with pytest.raises(SerializationError, match="4D"):
            serializer.serialize(
                keys=malformed, values=malformed,
                agent_id="t", task_description="t",
                model_id="test", output_path=tmp_data_dir / "bad.eng",
            )

    def test_missing_file_raises(self, tmp_data_dir: Path) -> None:
        serializer = EngramSerializer()
        with pytest.raises(SerializationError, match="not found"):
            serializer.deserialize(tmp_data_dir / "nonexistent.eng")
|
tests/test_state_extractor.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — State Extractor Tests
|
| 3 |
+
Tests for all 3 EGR extraction modes (D3).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from kvcos.core.cache_spec import LLAMA_3_1_8B, PHI_3_MINI
|
| 11 |
+
from kvcos.core.types import StateExtractionMode
|
| 12 |
+
from kvcos.core.state_extractor import MARStateExtractor
|
| 13 |
+
from tests.conftest import make_synthetic_kv
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestMeanPool:
    """mean_pool: mean over layers, heads, context → [head_dim]."""

    def test_output_dim(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        outcome = extractor.extract(keys, LLAMA_3_1_8B)
        assert outcome.state_vec.shape == (128,)

    def test_output_dim_api(self) -> None:
        # The advertised dimension must match what extract() produces.
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        assert extractor.output_dim(LLAMA_3_1_8B) == 128

    def test_l2_norm_positive(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        assert extractor.extract(keys, LLAMA_3_1_8B).l2_norm > 0

    def test_deterministic(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.MEAN_POOL)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        # Same keys in, same state vector out — bit-for-bit.
        first = extractor.extract(keys, LLAMA_3_1_8B)
        second = extractor.extract(keys, LLAMA_3_1_8B)
        assert torch.equal(first.state_vec, second.state_vec)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class TestSVDProject:
    """svd_project: truncated SVD, rank-160 → [rank]."""

    def test_output_dim(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT, rank=160)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        outcome = extractor.extract(keys, LLAMA_3_1_8B)
        # Requested rank 160 exceeds head_dim, so output is clamped to 128.
        assert outcome.state_vec.shape == (128,)

    def test_projection_stored(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT, rank=160)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        extractor.extract(keys, LLAMA_3_1_8B)
        # The extractor retains the last projection for inspection.
        projection = extractor.last_projection
        assert projection is not None
        assert 0.0 < projection.explained_variance_ratio <= 1.0

    def test_n_layers_used(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        outcome = extractor.extract(keys, LLAMA_3_1_8B)
        assert outcome.n_layers_used == 24  # layers 8-31
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class TestXKVProject:
    """xkv_project: grouped cross-layer SVD."""

    def test_output_dim(self) -> None:
        extractor = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT, rank=160)
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        outcome = extractor.extract(keys, LLAMA_3_1_8B)
        # extract() must honour its own advertised output dimension.
        assert outcome.state_vec.shape == (extractor.output_dim(LLAMA_3_1_8B),)

    def test_different_from_mean_pool(self) -> None:
        keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        pooled = MARStateExtractor(
            mode=StateExtractionMode.MEAN_POOL,
        ).extract(keys, LLAMA_3_1_8B)
        projected = MARStateExtractor(
            mode=StateExtractionMode.XKV_PROJECT,
        ).extract(keys, LLAMA_3_1_8B)
        # The two modes embed into different-sized spaces.
        assert pooled.state_vec.shape != projected.state_vec.shape

    def test_phi3_works(self) -> None:
        # Extraction should also work on a second model family (Phi-3).
        extractor = MARStateExtractor(mode=StateExtractionMode.XKV_PROJECT, rank=96)
        keys, _ = make_synthetic_kv(PHI_3_MINI, ctx_len=64)
        outcome = extractor.extract(keys, PHI_3_MINI)
        assert outcome.state_vec.dim() == 1
        assert outcome.state_vec.shape[0] > 0