File size: 3,721 Bytes
2ece486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Tests for kvcos.engram.hnsw_index — HNSW nearest-neighbor index."""

import torch
import pytest

from kvcos.engram.hnsw_index import EngramIndex, HNSWResult


@pytest.fixture
def small_index():
    """Return an 8-dimensional ``EngramIndex`` holding five documents.

    Documents are named ``doc_0`` .. ``doc_4``; document ``n`` is assigned
    row ``n`` of ``torch.eye(5, 8)``, so each vector is one-hot and the five
    vectors are mutually orthogonal.
    """
    doc_ids = [f"doc_{n}" for n in range(5)]
    # Deterministic vectors: row n has a single 1.0 at position n.
    vectors = torch.eye(5, 8)

    index = EngramIndex(dim=8)
    index.add_batch(doc_ids, vectors)
    return index


class TestEngramIndexBuild:
    """State of the index immediately after ``add_batch``."""

    def test_add_batch_len(self, small_index):
        """``len()`` reports the five documents that were added."""
        count = len(small_index)
        assert count == 5

    def test_add_batch_ids_stored(self, small_index):
        """IDs are kept in insertion order on the private ``_ids`` attribute."""
        expected_ids = ["doc_0", "doc_1", "doc_2", "doc_3", "doc_4"]
        assert small_index._ids == expected_ids

    def test_repr(self, small_index):
        """``repr`` mentions both the document count and the dimensionality."""
        text = repr(small_index)
        for fragment in ("n=5", "dim=8"):
            assert fragment in text


class TestEngramIndexSearch:
    """Behaviour of ``EngramIndex.search``."""

    def test_search_returns_results(self, small_index):
        """Querying with doc_0's own vector returns ``top_k`` hits, doc_0 first."""
        basis = torch.eye(5, 8)
        query = basis[0]  # identical to the vector stored for doc_0
        hits = small_index.search(query, top_k=3)
        assert len(hits) == 3
        best = hits[0]
        assert best.doc_id == "doc_0"

    def test_search_scores_descending(self, small_index):
        """Results come back ordered from highest score to lowest."""
        basis = torch.eye(5, 8)
        hits = small_index.search(basis[2], top_k=5)
        scores = [hit.score for hit in hits]
        expected_order = sorted(scores, reverse=True)
        assert scores == expected_order

    def test_search_margin(self, small_index):
        """The top hit exposes a non-negative ``margin``."""
        basis = torch.eye(5, 8)
        hits = small_index.search(basis[0], top_k=3)
        top_margin = hits[0].margin
        assert top_margin >= 0

    def test_search_raises_before_build(self):
        """Searching an index that never received vectors raises RuntimeError."""
        empty_index = EngramIndex(dim=8)
        with pytest.raises(RuntimeError, match="not built"):
            empty_index.search(torch.randn(8), top_k=1)


class TestEngramIndexGetVector:
    """Behaviour of ``EngramIndex.get_vector``."""

    def test_get_vector_returns_tensor(self, small_index):
        """A known ID yields a 1-D tensor with one entry per index dimension."""
        stored = small_index.get_vector("doc_0")
        assert stored is not None
        assert isinstance(stored, torch.Tensor)
        assert stored.shape == (8,)

    def test_get_vector_none_for_missing(self, small_index):
        """An unknown ID yields ``None``."""
        missing = small_index.get_vector("nonexistent")
        assert missing is None

    def test_get_vector_reconstructs_normalized(self, small_index):
        """Vectors are L2-normalized on add, so the stored vector has unit norm."""
        stored = small_index.get_vector("doc_0")
        length = torch.norm(stored).item()
        deviation = abs(length - 1.0)
        assert deviation < 0.01

    def test_get_vector_matches_original_direction(self, small_index):
        """The stored vector points the same way as the vector that was added."""
        # Slice with 3:4 to keep a leading batch axis for normalize().
        row = torch.eye(5, 8)[3:4]
        expected = torch.nn.functional.normalize(row, dim=-1)[0]
        stored = small_index.get_vector("doc_3")
        similarity = torch.dot(expected, stored).item()
        assert similarity > 0.99


class TestEngramIndexPersistence:
    """Round-tripping an index through ``save`` and ``EngramIndex.load``."""

    @staticmethod
    def _save_then_load(index, tmp_path):
        """Save *index* under ``tmp_path`` and return a freshly loaded copy."""
        target = str(tmp_path / "test_hnsw")
        index.save(target)
        return EngramIndex.load(target)

    def test_save_and_load(self, small_index, tmp_path):
        """A reloaded index keeps its size and its ID list."""
        restored = self._save_then_load(small_index, tmp_path)
        assert len(restored) == 5
        assert restored._ids == small_index._ids

    def test_loaded_search_matches_original(self, small_index, tmp_path):
        """A reloaded index ranks documents in the same order as the original."""
        restored = self._save_then_load(small_index, tmp_path)

        query = torch.eye(5, 8)[1]
        before = small_index.search(query, top_k=3)
        after = restored.search(query, top_k=3)

        ids_before = [hit.doc_id for hit in before]
        ids_after = [hit.doc_id for hit in after]
        assert ids_before == ids_after

    def test_loaded_get_vector(self, small_index, tmp_path):
        """Vectors can still be fetched by ID after a reload."""
        restored = self._save_then_load(small_index, tmp_path)

        stored = restored.get_vector("doc_2")
        assert stored is not None
        assert stored.shape == (8,)