| | """Tests for ARMS-HAT Python bindings.""" |
| |
|
| | import pytest |
| | import tempfile |
| | import os |
| |
|
| |
|
def test_import():
    """The public classes are importable from the ``arms_hat`` package."""
    # The import itself is the assertion: an ImportError fails the test.
    from arms_hat import HatIndex, HatConfig, SearchResult  # noqa: F401
| |
|
| |
|
def test_create_index():
    """A freshly built cosine index holds no points."""
    from arms_hat import HatIndex

    empty = HatIndex.cosine(128)

    # Both the length protocol and the explicit predicate should agree.
    assert len(empty) == 0
    assert empty.is_empty()
| |
|
| |
|
def test_add_and_query():
    """Inserted points are counted, and an exact-match query ranks first."""
    from arms_hat import HatIndex

    dim = 64
    hat = HatIndex.cosine(dim)

    # Each vector has a 1.0 component at slot n and a 0.5 component at n + 1.
    inserted = []
    for n in range(10):
        vec = [0.0] * dim
        vec[n % dim], vec[(n + 1) % dim] = 1.0, 0.5
        point_id = hat.add(vec)
        inserted.append(point_id)
        assert len(point_id) == 32

    assert len(hat) == 10
    assert not hat.is_empty()

    # The probe is identical to the first inserted vector (n == 0).
    probe = [0.0] * dim
    probe[:2] = [1.0, 0.5]

    hits = hat.near(probe, k=5)
    assert len(hits) == 5

    best = hits[0]
    assert best.id == inserted[0]
    assert best.score > 0.9
| |
|
| |
|
def test_sessions():
    """Chunks added on both sides of ``new_session()`` are all counted."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def one_hot(pos):
        # 1.0 at index ``pos % width``, 0.0 everywhere else.
        return [float(pos % width == k) for k in range(width)]

    # First batch: one-hot vectors on axes 0..4.
    for pos in range(5):
        hat.add(one_hot(pos))

    hat.new_session()

    # Second batch, after the session boundary: axes 10..14.
    for pos in range(10, 15):
        hat.add(one_hot(pos))

    summary = hat.stats()
    assert summary.session_count >= 1
    assert summary.chunk_count == 10
| |
|
| |
|
def test_documents():
    """Chunks added on both sides of ``new_document()`` are all counted."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def unit(axis):
        # Standard basis vector along ``axis``.
        return [1.0 if k == axis else 0.0 for k in range(width)]

    # First document: axes 0..2.
    for axis in range(3):
        hat.add(unit(axis))

    hat.new_document()

    # Second document: axes 10..12.
    for axis in range(10, 13):
        hat.add(unit(axis))

    summary = hat.stats()
    assert summary.document_count >= 1
    assert summary.chunk_count == 6
| |
|
| |
|
def test_persistence_bytes():
    """Round-trip an index through ``to_bytes`` / ``from_bytes``.

    Checks that the serialized form is non-empty, that the point count
    survives, and that both indexes return the same top hit for a query.
    """
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Twenty vectors on a 0.1 floor, each with a 1.0 peak on its own axis.
    ids = []
    for i in range(20):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        ids.append(index.add(embedding))

    data = index.to_bytes()
    assert len(data) > 0

    loaded = HatIndex.from_bytes(data)
    assert len(loaded) == len(index)

    # The query equals the first inserted vector, so it is the expected top hit.
    query = [0.1] * dims
    query[0] = 1.0

    original_results = index.near(query, k=5)
    loaded_results = loaded.near(query, k=5)

    assert len(original_results) == len(loaded_results)

    # Only the top hit is compared. Every other vector has exactly the same
    # cosine score against this query, so their relative order may differ.
    assert original_results[0].id == ids[0]
    assert loaded_results[0].id == ids[0]
    assert abs(original_results[0].score - loaded_results[0].score) <= 1e-6
| |
|
| |
|
def test_persistence_file():
    """Save an index to a file, load it back, and check that it round-trips."""
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Ten vectors on a 0.1 floor, each with a 1.0 peak on its own axis.
    for i in range(10):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        index.add(embedding)

    # Use a path inside a fresh temporary directory instead of a
    # NamedTemporaryFile. NamedTemporaryFile creates the file up front, which
    # made the ``exists`` check pass even if ``save`` wrote nothing. The
    # directory (and anything in it) is also removed if save/load raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "index.hat")

        index.save(path)
        assert os.path.exists(path)
        assert os.path.getsize(path) > 0

        loaded = HatIndex.load(path)
        assert len(loaded) == len(index)

        # The query equals the first inserted vector; both indexes should
        # agree on it as the top hit.
        query = [0.1] * dims
        query[0] = 1.0
        assert loaded.near(query, k=1)[0].id == index.near(query, k=1)[0].id
| |
|
| |
|
def test_config():
    """An index can be built from a customised ``HatConfig``."""
    from arms_hat import HatIndex, HatConfig

    # Each with_* builder returns a config, so the calls can be chained.
    cfg = HatConfig().with_beam_width(5).with_temporal_weight(0.1)

    hat = HatIndex.with_config(128, cfg)
    assert len(hat) == 0
| |
|
| |
|
def test_remove():
    """Removing a point drops it from both the count and search results."""
    from arms_hat import HatIndex

    dim = 32
    hat = HatIndex.cosine(dim)

    e0 = [0.0] * dim
    e0[0] = 1.0
    e1 = [0.0] * dim
    e1[1] = 1.0

    first = hat.add(e0)
    second = hat.add(e1)
    assert len(hat) == 2

    hat.remove(first)
    assert len(hat) == 1

    # Querying with the survivor's own direction must return only the survivor.
    probe = [0.0] * dim
    probe[1] = 1.0
    remaining = hat.near(probe, k=5)
    assert len(remaining) == 1
    assert remaining[0].id == second
| |
|
| |
|
def test_consolidate():
    """Both consolidation passes leave every stored point in place."""
    from arms_hat import HatIndex

    dim, count = 32, 100
    hat = HatIndex.cosine(dim)

    # 100 one-hot vectors cycling through the 32 axes.
    for n in range(count):
        hat.add([1.0 if k == n % dim else 0.0 for k in range(dim)])

    hat.consolidate()
    hat.consolidate_full()

    assert len(hat) == count
| |
|
| |
|
def test_stats():
    """``stats()`` reports chunk and point totals for a small index."""
    from arms_hat import HatIndex

    dim = 64
    hat = HatIndex.cosine(dim)

    # Ten one-hot vectors on axes 0..9.
    for n in range(10):
        vec = [0.0] * dim
        vec[n % dim] = 1.0
        hat.add(vec)

    summary = hat.stats()
    assert summary.chunk_count == 10
    assert summary.total_points == 10
| |
|
| |
|
def test_repr():
    """``repr()`` of the index and the config names the respective class."""
    # Only the two classes exercised below are imported; SearchResult was
    # imported here previously but never used.
    from arms_hat import HatIndex, HatConfig

    index = HatIndex.cosine(64)
    assert "HatIndex" in repr(index)

    config = HatConfig()
    assert "HatConfig" in repr(config)
| |
|
| |
|
def test_near_sessions():
    """Coarse-grained session search returns at most k sessions, best first."""
    from arms_hat import HatIndex

    dims = 32
    k = 2
    index = HatIndex.cosine(dims)

    # Vectors added before new_session(): dominated by axis 0, each with a
    # small 0.3 component on a distinct axis 1..5.
    for i in range(5):
        embedding = [0.0] * dims
        embedding[0] = 1.0
        embedding[i + 1] = 0.3
        index.add(embedding)

    index.new_session()

    # Vectors added after new_session(): dominated by axis 10, each with a
    # small 0.3 component on a distinct axis 11..15.
    for i in range(5):
        embedding = [0.0] * dims
        embedding[10] = 1.0
        embedding[i + 11] = 0.3
        index.add(embedding)

    # The query points along axis 0, toward the first group of vectors.
    query = [0.0] * dims
    query[0] = 1.0

    sessions = index.near_sessions(query, k=k)
    assert 1 <= len(sessions) <= k

    # Scores must be non-increasing across the whole result list, not only
    # between the first two entries.
    scores = [s.score for s in sessions]
    assert all(a >= b for a, b in zip(scores, scores[1:]))
| |
|
| |
|
def test_high_dimensions():
    """Test with OpenAI embedding dimensions (1536)."""
    import math

    from arms_hat import HatIndex

    dims = 1536
    index = HatIndex.cosine(dims)

    # Scale by (i + 1) rather than i. With i == 0 every component would be
    # 0.0, producing a zero vector whose cosine similarity is undefined.
    for i in range(10):
        embedding = [(j * (i + 1) * 0.01) % 1.0 for j in range(dims)]
        index.add(embedding)

    assert len(index) == 10

    query = [0.5] * dims
    results = index.near(query, k=5)
    assert len(results) == 5
    # No NaN/inf scores should leak out of a degenerate norm.
    assert all(math.isfinite(r.score) for r in results)
| |
|
| |
|
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly reports
    # test failures to the shell/CI, instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
| |
|