| """ |
| ENGRAM Protocol — Manifold Index Tests |
| Tests for FAISS IndexFlatIP add/search/remove/persist (D2, D4). |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
| from kvcos.core.manifold_index import IndexEntry, ManifoldIndex |
|
|
|
|
def _entry(cid: str = "c1", model: str = "llama") -> IndexEntry:
    """Build an ``IndexEntry`` with fixed placeholder metadata for tests.

    Only the cache id and the model id vary between callers; every other
    field is a constant.

    Args:
        cid: Value stored in the entry's ``cache_id`` field.
        model: Value stored in the entry's ``model_id`` field.

    Returns:
        A freshly constructed ``IndexEntry``.
    """
    fields = {
        "cache_id": cid,
        "task_description": "test",
        "model_id": model,
        "created_at": "2026-01-01T00:00:00Z",
        "context_len": 256,
        "l2_norm": 1.0,
    }
    return IndexEntry(**fields)
|
|
|
|
class TestAddAndSearch:
    """Add vectors, search via MIPS."""

    def test_add_increments(self) -> None:
        """Each ``add`` call is reflected in ``n_entries``."""
        idx = ManifoldIndex(dim=8)
        idx.add(torch.randn(8), _entry("a"))
        idx.add(torch.randn(8), _entry("b"))
        assert idx.n_entries == 2

    def test_search_returns_correct_order(self) -> None:
        """The entry whose vector equals the query is ranked first."""
        idx = ManifoldIndex(dim=4)
        # Two orthogonal unit vectors: the query below matches v1 exactly and
        # has a zero inner product with v2.
        v1 = torch.tensor([1.0, 0.0, 0.0, 0.0])
        v2 = torch.tensor([0.0, 1.0, 0.0, 0.0])
        idx.add(v1, _entry("close"))
        idx.add(v2, _entry("far"))

        query = torch.tensor([1.0, 0.0, 0.0, 0.0])
        results = idx.search(query, top_k=2)
        # Fail with a clear assertion (not an IndexError) if an entry is
        # missing from the results.
        assert len(results) == 2
        assert results[0]["cache_id"] == "close"
        assert results[0]["similarity"] > results[1]["similarity"]

    def test_search_empty_returns_empty(self) -> None:
        """Searching an index with no entries yields an empty list."""
        idx = ManifoldIndex(dim=4)
        results = idx.search(torch.randn(4), top_k=5)
        assert results == []

    def test_model_filter(self) -> None:
        """``model_id`` restricts the results to entries of that model."""
        idx = ManifoldIndex(dim=4)
        idx.add(torch.randn(4), _entry("a", model="llama"))
        idx.add(torch.randn(4), _entry("b", model="phi"))
        results = idx.search(torch.randn(4), top_k=10, model_id="phi")
        # all() over an empty sequence is True, so without this check the
        # test would pass even if the filter dropped every entry.
        assert results, "expected the 'phi' entry to survive the model filter"
        assert all(r["model_id"] == "phi" for r in results)
|
|
|
|
class TestRemoveAndRebuild:
    """Entry removal and subsequent index rebuild."""

    def test_remove_hides_from_search(self) -> None:
        """A removed entry no longer appears in search results."""
        index = ManifoldIndex(dim=4)
        vector = torch.tensor([1.0, 0.0, 0.0, 0.0])
        index.add(vector, _entry("target"))

        removed = index.remove("target")
        assert removed

        # Query with the very vector that was stored: nothing should match.
        hits = index.search(vector, top_k=1)
        assert len(hits) == 0

    def test_rebuild_compacts(self) -> None:
        """``rebuild`` reports the number of entries left after removals."""
        index = ManifoldIndex(dim=4)
        cache_ids = [f"c{i}" for i in range(5)]
        for cache_id in cache_ids:
            index.add(torch.randn(4), _entry(cache_id))

        # Drop two of the five entries before rebuilding.
        for doomed in ("c1", "c3"):
            index.remove(doomed)

        remaining = index.rebuild()
        assert remaining == 3
|
|
|
|
class TestPersistence:
    """Save/load round-trip (D2: serialize_index/deserialize_index)."""

    def test_save_load_round_trip(self, tmp_index_dir: Path) -> None:
        """An index saved to disk can be reloaded and searched.

        Args:
            tmp_index_dir: Directory fixture supplied from outside this file
                (presumably a per-test temporary directory -- confirm in
                conftest).
        """
        index_file = tmp_index_dir / "test.faiss"
        vector = torch.tensor([1.0, 0.0, 0.0, 0.0])

        original = ManifoldIndex(dim=4)
        original.add(vector, _entry("persisted"))
        original.save(index_file)

        # A second instance pointed at the saved file must see the same entry.
        reloaded = ManifoldIndex(dim=4, index_path=index_file)
        assert reloaded.n_entries == 1
        hits = reloaded.search(vector, top_k=1)
        assert hits[0]["cache_id"] == "persisted"

    def test_dim_mismatch_raises(self) -> None:
        """Adding a vector of the wrong dimensionality raises ``ValueError``."""
        index = ManifoldIndex(dim=4)
        # The index is 4-dimensional; an 8-dimensional vector must be rejected
        # with an error message matching "dim".
        with pytest.raises(ValueError, match="dim"):
            index.add(torch.randn(8), _entry("wrong"))
|
|