# engram/tests/test_hnsw_index.py
# Uploaded by eigengram; commit 2ece486 (verified): "test: upload 220 tests"
"""Tests for kvcos.engram.hnsw_index — HNSW nearest-neighbor index."""
import torch
import pytest
from kvcos.engram.hnsw_index import EngramIndex, HNSWResult
@pytest.fixture
def small_index():
    """Provide an 8-dimensional ``EngramIndex`` holding five documents.

    The documents are named ``doc_0`` .. ``doc_4``; document ``n`` is added
    with row ``n`` of ``torch.eye(5, 8)``, i.e. a one-hot vector whose single
    1.0 sits at position ``n``.
    """
    doc_count, dim = 5, 8
    index = EngramIndex(dim=dim)
    # One-hot rows: deterministic and mutually orthogonal inputs.
    basis_rows = torch.eye(doc_count, dim)
    doc_ids = [f"doc_{n}" for n in range(doc_count)]
    index.add_batch(doc_ids, basis_rows)
    return index
class TestEngramIndexBuild:
    """Checks on the index state right after the fixture's ``add_batch`` call."""

    def test_add_batch_len(self, small_index):
        """``len()`` of the index equals the number of documents added."""
        doc_count = len(small_index)
        assert doc_count == 5

    def test_add_batch_ids_stored(self, small_index):
        """The private ``_ids`` list holds the added ids in the order given."""
        expected_ids = [f"doc_{n}" for n in range(5)]
        assert small_index._ids == expected_ids

    def test_repr(self, small_index):
        """``repr()`` mentions both the document count and the dimension."""
        text = repr(small_index)
        for fragment in ("n=5", "dim=8"):
            assert fragment in text
class TestEngramIndexSearch:
    """Nearest-neighbour queries against the five-document fixture index."""

    def test_search_returns_results(self, small_index):
        """A query equal to doc_0's input vector yields ``top_k`` hits, doc_0 first."""
        query = torch.eye(5, 8)[0]  # matches doc_0
        results = small_index.search(query, top_k=3)
        assert len(results) == 3
        assert results[0].doc_id == "doc_0"

    def test_search_scores_descending(self, small_index):
        """Hits are ordered from highest to lowest score."""
        query = torch.eye(5, 8)[2]
        results = small_index.search(query, top_k=5)
        # Guard first: an empty result list would satisfy the ordering check
        # below vacuously ([] == sorted([])) and hide a broken search.
        assert len(results) > 0, "search returned no results"
        scores = [r.score for r in results]
        assert scores == sorted(scores, reverse=True)

    def test_search_margin(self, small_index):
        """The top hit reports a non-negative ``margin``."""
        query = torch.eye(5, 8)[0]
        results = small_index.search(query, top_k=3)
        # Fail with a readable assertion instead of an IndexError on results[0].
        assert len(results) > 0, "search returned no results"
        assert results[0].margin >= 0

    def test_search_raises_before_build(self):
        """Searching an index that has had nothing added raises ``RuntimeError``."""
        idx = EngramIndex(dim=8)
        # The query values are irrelevant here; only the error path is exercised.
        with pytest.raises(RuntimeError, match="not built"):
            idx.search(torch.randn(8), top_k=1)
class TestEngramIndexGetVector:
    """Lookup of stored vectors by document id via ``get_vector``."""

    def test_get_vector_returns_tensor(self, small_index):
        """A known id yields a 1-D ``torch.Tensor`` of length 8."""
        vec = small_index.get_vector("doc_0")
        assert vec is not None
        assert isinstance(vec, torch.Tensor)
        assert vec.shape == (8,)

    def test_get_vector_none_for_missing(self, small_index):
        """An unknown id yields ``None`` rather than raising."""
        vec = small_index.get_vector("nonexistent")
        assert vec is None

    def test_get_vector_reconstructs_normalized(self, small_index):
        """Vectors are expected to be L2-normalized on add, so the returned vector should be unit-length."""
        vec = small_index.get_vector("doc_0")
        # get_vector signals a failed lookup with None; check that explicitly so a
        # lookup failure is reported as such instead of as an error inside torch.norm.
        assert vec is not None, "get_vector('doc_0') returned None"
        norm = torch.norm(vec).item()
        assert abs(norm - 1.0) < 0.01

    def test_get_vector_matches_original_direction(self, small_index):
        """Reconstructed vector should point in the same direction as the original."""
        # Slice [3:4] keeps a leading batch axis for normalize(); [0] drops it again.
        original = torch.nn.functional.normalize(torch.eye(5, 8)[3:4], dim=-1)[0]
        reconstructed = small_index.get_vector("doc_3")
        # Same guard as above: surface a missing vector as an assertion failure,
        # not as an error raised from torch.dot.
        assert reconstructed is not None, "get_vector('doc_3') returned None"
        cosine = torch.dot(original, reconstructed).item()
        assert cosine > 0.99
class TestEngramIndexPersistence:
    """Round-trip checks: save the fixture index, load it back, compare."""

    @staticmethod
    def _save_then_load(index, directory):
        """Save *index* under ``directory / "test_hnsw"`` and return the reloaded copy."""
        target = str(directory / "test_hnsw")
        index.save(target)
        return EngramIndex.load(target)

    def test_save_and_load(self, small_index, tmp_path):
        """The reloaded index has the same length and the same ``_ids``."""
        restored = self._save_then_load(small_index, tmp_path)
        assert len(restored) == 5
        assert restored._ids == small_index._ids

    def test_loaded_search_matches_original(self, small_index, tmp_path):
        """The same query returns the same doc ids, in order, before and after reload."""
        restored = self._save_then_load(small_index, tmp_path)
        probe = torch.eye(5, 8)[1]
        hits_before = small_index.search(probe, top_k=3)
        hits_after = restored.search(probe, top_k=3)
        ids_before = [hit.doc_id for hit in hits_before]
        ids_after = [hit.doc_id for hit in hits_after]
        assert ids_before == ids_after

    def test_loaded_get_vector(self, small_index, tmp_path):
        """``get_vector`` on the reloaded index returns a length-8 vector."""
        restored = self._save_then_load(small_index, tmp_path)
        vec = restored.get_vector("doc_2")
        assert vec is not None
        assert vec.shape == (8,)