File size: 3,266 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
ENGRAM Protocol — Retriever Tests
Tests for EGRRetriever: store → index → query → retrieve pipeline.
"""

from __future__ import annotations

from pathlib import Path

import torch

from kvcos.core.cache_spec import LLAMA_3_1_8B
from kvcos.core.serializer import EngramSerializer
from kvcos.core.types import CompressionMethod, StateExtractionMode
from kvcos.core.manifold_index import ManifoldIndex
from kvcos.core.retriever import EGRRetriever, RetrievalResponse
from kvcos.core.state_extractor import MARStateExtractor
from kvcos.storage.local import LocalStorageBackend
from tests.conftest import make_synthetic_kv


def _build_retriever(
    data_dir: Path,
    mode: StateExtractionMode = StateExtractionMode.MEAN_POOL,
) -> EGRRetriever:
    """Assemble an ``EGRRetriever`` from a fresh extractor, index and storage.

    Args:
        data_dir: Directory handed to ``LocalStorageBackend``.
        mode: State-extraction mode passed to ``MARStateExtractor``.

    Returns:
        A retriever composed of the extractor, a ``ManifoldIndex`` and the
        local storage backend.
    """
    extractor = MARStateExtractor(mode=mode, rank=128)
    # Size the index from the extractor's output dimension for this spec.
    index = ManifoldIndex(dim=extractor.output_dim(LLAMA_3_1_8B))
    backend = LocalStorageBackend(data_dir=data_dir)
    return EGRRetriever(extractor, index, backend)


class TestIndexAndRetrieve:
    """End-to-end checks: index an engram, search for it, load it back."""

    @staticmethod
    def _index_sample(
        retriever: EGRRetriever,
        keys,
        values,
        output_dir: Path,
        task_description: str,
    ):
        """Index one synthetic engram with the settings shared by these tests.

        Only ``task_description`` varies between callers; the spec, agent id
        and model id are fixed. Returns whatever ``index_engram`` returns.
        """
        return retriever.index_engram(
            keys=keys,
            values=values,
            spec=LLAMA_3_1_8B,
            agent_id="test",
            task_description=task_description,
            model_id=LLAMA_3_1_8B["model_id"],
            output_dir=output_dir,
        )

    def test_index_returns_cache_id(self, tmp_data_dir: Path) -> None:
        """Indexing hands back a non-empty string identifier."""
        k_cache, v_cache = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)

        cache_id = self._index_sample(
            retriever, k_cache, v_cache, tmp_data_dir, "test engram",
        )

        assert isinstance(cache_id, str)
        assert len(cache_id) > 0

    def test_retrieve_finds_stored(self, tmp_data_dir: Path) -> None:
        """A top-1 query on a one-entry index yields keys of the stored shape."""
        k_cache, v_cache = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        self._index_sample(
            retriever, k_cache, v_cache, tmp_data_dir, "findable engram",
        )

        # Query with keys generated from a different seed than the stored ones.
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64, seed=99)
        outcome = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=1)

        assert isinstance(outcome, RetrievalResponse)
        assert len(outcome.results) == 1
        assert outcome.results[0].keys.shape == k_cache.shape

    def test_retrieve_empty_index(self, tmp_data_dir: Path) -> None:
        """Querying before anything is indexed returns no results."""
        retriever = _build_retriever(tmp_data_dir)
        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)

        outcome = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)

        assert len(outcome.results) == 0

    def test_delete_removes(self, tmp_data_dir: Path) -> None:
        """After a successful delete, the engram no longer shows up in queries."""
        k_cache, v_cache = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        retriever = _build_retriever(tmp_data_dir)
        cache_id = self._index_sample(
            retriever, k_cache, v_cache, tmp_data_dir, "deletable",
        )

        assert retriever.delete_engram(cache_id)

        probe_keys, _ = make_synthetic_kv(LLAMA_3_1_8B, ctx_len=64)
        outcome = retriever.retrieve(probe_keys, LLAMA_3_1_8B, top_k=5)
        assert len(outcome.results) == 0