"""
ENGRAM Protocol — Retriever Tests
Tests for EGRRetriever: store → index → query → retrieve pipeline.
"""
from __future__ import annotations
from pathlib import Path
import torch
from kvcos.core.cache_spec import LLAMA_3_1_8B
from kvcos.core.serializer import EngramSerializer
from kvcos.core.types import CompressionMethod, StateExtractionMode
from kvcos.core.manifold_index import ManifoldIndex
from kvcos.core.retriever import EGRRetriever, RetrievalResponse
from kvcos.core.state_extractor import MARStateExtractor
from kvcos.storage.local import LocalStorageBackend
from tests.conftest import make_synthetic_kv
def _build_retriever(
    data_dir: Path,
    mode: StateExtractionMode = StateExtractionMode.MEAN_POOL,
) -> EGRRetriever:
    """Assemble an ``EGRRetriever`` for tests.

    Wires together a ``MARStateExtractor`` (rank 128, using ``mode``), a
    ``ManifoldIndex`` whose dimension is ``extractor.output_dim(LLAMA_3_1_8B)``,
    and a ``LocalStorageBackend`` rooted at ``data_dir``.

    Args:
        data_dir: Directory handed to the local storage backend.
        mode: State-extraction mode for the extractor; defaults to MEAN_POOL.

    Returns:
        A retriever built from the extractor, index and storage above.
    """
    extractor = MARStateExtractor(mode=mode, rank=128)
    # Size the index from the extractor so both agree on vector width.
    index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
    backend = LocalStorageBackend(data_dir=data_dir)
    return EGRRetriever(extractor, index, backend)
class TestIndexAndRetrieve:
    """Full store → search → load pipeline."""

    @staticmethod
    def _store(retriever, keys, values, out_dir: Path, description: str):
        """Index one synthetic engram for agent ``"test"`` and return its id.

        All engrams in this suite use the LLAMA_3_1_8B spec and model id;
        only the task description varies between tests.
        """
        return retriever.index_engram(
            keys=keys,
            values=values,
            spec=LLAMA_3_1_8B,
            agent_id="test",
            task_description=description,
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=out_dir,
        )

    def test_index_returns_cache_id(self, tmp_data_dir: Path) -> None:
        """Indexing an engram yields a non-empty string identifier."""
        k, v = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        cache_id = self._store(retriever, k, v, tmp_data_dir, "test engram")
        assert isinstance(cache_id, str)
        assert len(cache_id) > 0

    def test_retrieve_finds_stored(self, tmp_data_dir: Path) -> None:
        """With one engram indexed, a top-1 query returns it with its key shape."""
        k, v = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        self._store(retriever, k, v, tmp_data_dir, "findable engram")

        # A differently-seeded query still has to hit the only stored entry.
        probe, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=99)
        response = retriever.retrieve(probe, LLAMA_3_1_8B, top_k=1)

        assert isinstance(response, RetrievalResponse)
        assert len(response.results) == 1
        assert response.results[0].keys.shape == k.shape

    def test_retrieve_empty_index(self, tmp_data_dir: Path) -> None:
        """Querying a retriever with nothing indexed returns no results."""
        retriever = _build_retriever(tmp_data_dir)
        probe, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        assert len(retriever.retrieve(probe, LLAMA_3_1_8B, top_k=5).results) == 0

    def test_delete_removes(self, tmp_data_dir: Path) -> None:
        """A deleted engram no longer appears in retrieval results."""
        k, v = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        cache_id = self._store(retriever, k, v, tmp_data_dir, "deletable")

        assert retriever.delete_engram(cache_id)

        # Same default seed as the stored keys: if deletion silently failed,
        # this query would find the engram.
        probe, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        assert len(retriever.retrieve(probe, LLAMA_3_1_8B, top_k=5).results) == 0