"""
ENGRAM Protocol — Manifold Index Tests
Tests for FAISS IndexFlatIP add/search/remove/persist (D2, D4).
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
import torch
from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
def _entry(cid: str = "c1", model: str = "llama") -> IndexEntry:
    """Build a minimal ``IndexEntry`` for use in index tests.

    Args:
        cid: Value stored as the entry's ``cache_id``.
        model: Value stored as the entry's ``model_id``.

    Returns:
        An ``IndexEntry`` whose remaining metadata (task description,
        creation timestamp, context length, L2 norm) is fixed placeholder data.
    """
    placeholder_fields = {
        "task_description": "test",
        "created_at": "2026-01-01T00:00:00Z",
        "context_len": 256,
        "l2_norm": 1.0,
    }
    return IndexEntry(cache_id=cid, model_id=model, **placeholder_fields)


class TestAddAndSearch:
    """Add vectors, search via MIPS."""

    def test_add_increments(self) -> None:
        """Each ``add`` call increases ``n_entries`` by one."""
        idx = ManifoldIndex(dim=8)
        idx.add(torch.randn(8), _entry("a"))
        idx.add(torch.randn(8), _entry("b"))
        assert idx.n_entries == 2

    def test_search_returns_correct_order(self) -> None:
        """The entry most similar to the query is ranked first."""
        idx = ManifoldIndex(dim=4)
        v1 = torch.tensor([1.0, 0.0, 0.0, 0.0])
        v2 = torch.tensor([0.0, 1.0, 0.0, 0.0])
        idx.add(v1, _entry("close"))
        idx.add(v2, _entry("far"))
        query = torch.tensor([1.0, 0.0, 0.0, 0.0])
        results = idx.search(query, top_k=2)
        # Check the count first so a short result list fails with a clear
        # message instead of an IndexError on results[1].
        assert len(results) == 2
        assert results[0]["cache_id"] == "close"
        assert results[0]["similarity"] > results[1]["similarity"]

    def test_search_empty_returns_empty(self) -> None:
        """Searching an index with no entries yields an empty list."""
        idx = ManifoldIndex(dim=4)
        results = idx.search(torch.randn(4), top_k=5)
        assert results == []

    def test_model_filter(self) -> None:
        """``model_id`` restricts results to entries from that model only."""
        idx = ManifoldIndex(dim=4)
        # Both entries share the query vector, so they are equally similar and
        # only the model_id filter can tell them apart (no random vectors).
        v = torch.tensor([1.0, 0.0, 0.0, 0.0])
        idx.add(v, _entry("a", model="llama"))
        idx.add(v, _entry("b", model="phi"))
        results = idx.search(v, top_k=10, model_id="phi")
        # all() over an empty list is vacuously True, so pin the exact result;
        # otherwise a filter that dropped everything would still pass.
        assert [r["cache_id"] for r in results] == ["b"]
        assert all(r["model_id"] == "phi" for r in results)


class TestRemoveAndRebuild:
    """Removing entries from the index and compacting it with ``rebuild``."""

    def test_remove_hides_from_search(self) -> None:
        """A removed entry no longer shows up in search results."""
        index = ManifoldIndex(dim=4)
        target_vec = torch.tensor([1.0, 0.0, 0.0, 0.0])
        index.add(target_vec, _entry("target"))
        removed = index.remove("target")
        assert removed
        hits = index.search(target_vec, top_k=1)
        assert len(hits) == 0

    def test_rebuild_compacts(self) -> None:
        """``rebuild`` returns the count of entries that survived removal."""
        index = ManifoldIndex(dim=4)
        for n in range(5):
            index.add(torch.randn(4), _entry(f"c{n}"))
        for doomed in ("c1", "c3"):
            index.remove(doomed)
        survivors = index.rebuild()
        assert survivors == 3


class TestPersistence:
    """Save/load round-trip (D2: serialize_index/deserialize_index)."""

    def test_save_load_round_trip(self, tmp_index_dir: Path) -> None:
        """An index saved to disk and reopened via ``index_path`` keeps its entries."""
        # NOTE(review): tmp_index_dir is a fixture defined outside this file
        # (presumably conftest.py, yielding a temporary directory) — confirm.
        idx = ManifoldIndex(dim=4)
        v1 = torch.tensor([1.0, 0.0, 0.0, 0.0])
        idx.add(v1, _entry("persisted"))
        index_file = tmp_index_dir / "test.faiss"
        idx.save(index_file)

        idx2 = ManifoldIndex(dim=4, index_path=index_file)
        assert idx2.n_entries == 1
        results = idx2.search(v1, top_k=1)
        # Check the count so an empty result fails clearly rather than with
        # an IndexError on results[0].
        assert len(results) == 1
        assert results[0]["cache_id"] == "persisted"

    def test_dim_mismatch_raises(self) -> None:
        """Adding a vector whose length differs from ``dim`` raises ValueError."""
        idx = ManifoldIndex(dim=4)
        with pytest.raises(ValueError, match="dim"):
            idx.add(torch.randn(8), _entry("wrong"))