| """ |
| ENGRAM Protocol — Retriever Tests |
| Tests for EGRRetriever: store → index → query → retrieve pipeline. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import torch |
|
|
| from kvcos.core.cache_spec import LLAMA_3_1_8B |
| from kvcos.core.serializer import EngramSerializer |
| from kvcos.core.types import CompressionMethod, StateExtractionMode |
| from kvcos.core.manifold_index import ManifoldIndex |
| from kvcos.core.retriever import EGRRetriever, RetrievalResponse |
| from kvcos.core.state_extractor import MARStateExtractor |
| from kvcos.storage.local import LocalStorageBackend |
| from tests.conftest import make_synthetic_kv |
|
|
|
|
def _build_retriever(
    data_dir: Path, mode: StateExtractionMode = StateExtractionMode.MEAN_POOL,
) -> EGRRetriever:
    """Assemble an EGRRetriever whose storage backend lives under *data_dir*.

    The state extractor is created with the requested *mode* and ``rank=128``;
    the manifold index is sized to the extractor's output dimension for the
    ``LLAMA_3_1_8B`` cache spec.
    """
    extractor = MARStateExtractor(mode=mode, rank=128)
    # The index dimensionality must match what the extractor emits for this spec.
    index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
    backend = LocalStorageBackend(data_dir=data_dir)
    return EGRRetriever(extractor, index, backend)
|
|
|
|
class TestIndexAndRetrieve:
    """End-to-end checks of the store → search → load pipeline."""

    @staticmethod
    def _store(retriever: EGRRetriever, keys, values, description: str, out_dir: Path):
        """Index one synthetic engram and return whatever ``index_engram`` returns.

        Every test stores with the same spec, agent id and model id; only the
        task description and tensors vary, so the shared arguments live here.
        """
        return retriever.index_engram(
            keys=keys,
            values=values,
            spec=LLAMA_3_1_8B,
            agent_id="test",
            task_description=description,
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=out_dir,
        )

    def test_index_returns_cache_id(self, tmp_data_dir: Path) -> None:
        """Indexing an engram yields a non-empty string identifier."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)

        cache_id = self._store(retriever, keys, values, "test engram", tmp_data_dir)

        assert isinstance(cache_id, str)
        assert cache_id  # a str is truthy exactly when it is non-empty

    def test_retrieve_finds_stored(self, tmp_data_dir: Path) -> None:
        """A top-1 query returns one result whose keys match the stored shape."""
        stored_keys, stored_values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        self._store(
            retriever, stored_keys, stored_values, "findable engram", tmp_data_dir,
        )

        # The query tensors are generated with seed=99.
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=99)
        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=1)

        assert isinstance(response, RetrievalResponse)
        assert len(response.results) == 1
        top_hit = response.results[0]
        assert top_hit.keys.shape == stored_keys.shape

    def test_retrieve_empty_index(self, tmp_data_dir: Path) -> None:
        """Querying before anything is indexed returns no results."""
        retriever = _build_retriever(tmp_data_dir)
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)

        assert len(response.results) == 0

    def test_delete_removes(self, tmp_data_dir: Path) -> None:
        """After a successful delete, a query returns no results."""
        keys, values = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        cache_id = self._store(retriever, keys, values, "deletable", tmp_data_dir)

        # delete_engram must report success (a truthy value) for the stored id.
        assert retriever.delete_engram(cache_id)

        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        response = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)

        assert len(response.results) == 0
|
|