Spaces:
Running
Running
| """Comprehensive unit tests for CompressionStore. | |
| Tests cover: | |
| 1. CompressionStore class initialization | |
| 2. Storing compressed content with hash generation | |
| 3. Retrieving content by hash | |
| 4. TTL expiration behavior | |
| 5. Memory limits and eviction | |
| 6. Statistics tracking | |
| 7. Edge cases (empty content, duplicate stores, etc.) | |
| 8. Thread safety | |
| 9. Feedback loop integration | |
| 10. Search functionality | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import threading | |
| import time | |
| from typing import Any | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
| from headroom.cache.compression_store import ( | |
| CompressionEntry, | |
| CompressionStore, | |
| RetrievalEvent, | |
| get_compression_store, | |
| reset_compression_store, | |
| ) | |
| # ============================================================================= | |
| # Fixtures | |
| # ============================================================================= | |
@pytest.fixture(autouse=True)
def reset_global_store():
    """Reset the global compression store before and after each test.

    Registered as an autouse fixture so that every test in this module
    starts from, and leaves behind, a freshly reset global store without
    having to request the fixture explicitly.
    """
    reset_compression_store()
    # Everything after the yield runs as teardown once the test finishes.
    yield
    reset_compression_store()
@pytest.fixture
def store() -> CompressionStore:
    """Provide a fresh CompressionStore built with default arguments.

    Returns:
        A new CompressionStore instance for the requesting test.
    """
    return CompressionStore()
@pytest.fixture
def store_with_short_ttl() -> CompressionStore:
    """Provide a CompressionStore whose default TTL is one second.

    Used by the expiration tests, which sleep slightly longer than the TTL
    before checking that entries are gone.

    Returns:
        A new CompressionStore created with ``default_ttl=1``.
    """
    return CompressionStore(default_ttl=1)
@pytest.fixture
def store_with_small_capacity() -> CompressionStore:
    """Provide a CompressionStore limited to three entries.

    Used by the eviction tests so capacity is reached after a few stores.

    Returns:
        A new CompressionStore created with ``max_entries=3``.
    """
    return CompressionStore(max_entries=3)
@pytest.fixture
def sample_items() -> list[dict[str, Any]]:
    """Provide 100 sample item dicts for testing.

    Returns:
        A list of dicts with keys ``id`` (0..99), ``name`` (``item_<id>``)
        and ``value`` (``id * 10``).
    """
    return [{"id": i, "name": f"item_{i}", "value": i * 10} for i in range(100)]
@pytest.fixture
def sample_original(sample_items: list[dict[str, Any]]) -> str:
    """Provide the sample "original" payload as a JSON string.

    Args:
        sample_items: The full list of sample items (injected fixture).

    Returns:
        ``sample_items`` serialized with ``json.dumps``.
    """
    return json.dumps(sample_items)
@pytest.fixture
def sample_compressed(sample_items: list[dict[str, Any]]) -> str:
    """Provide the sample "compressed" payload as a JSON string.

    Args:
        sample_items: The full list of sample items (injected fixture).

    Returns:
        The first 10 entries of ``sample_items`` serialized with
        ``json.dumps``.
    """
    return json.dumps(sample_items[:10])
| # ============================================================================= | |
| # CompressionEntry Tests | |
| # ============================================================================= | |
class TestCompressionEntry:
    """Behaviour of the CompressionEntry dataclass."""

    @staticmethod
    def _make_entry(**overrides: Any) -> CompressionEntry:
        """Build a CompressionEntry from baseline fields plus ``overrides``.

        ``ttl`` is deliberately absent from the baseline so that the
        dataclass default applies unless a test overrides it.
        """
        fields: dict[str, Any] = {
            "hash": "abc123",
            "original_content": "[1]",
            "compressed_content": "[]",
            "original_tokens": 10,
            "compressed_tokens": 0,
            "original_item_count": 1,
            "compressed_item_count": 0,
            "tool_name": None,
            "tool_call_id": None,
            "query_context": None,
            "created_at": time.time(),
        }
        fields.update(overrides)
        return CompressionEntry(**fields)

    def test_entry_creation_with_defaults(self):
        """An entry built from only the required fields gets the defaults."""
        created = self._make_entry(
            original_content="[1,2,3]",
            compressed_content="[1]",
            original_tokens=100,
            compressed_tokens=10,
            original_item_count=3,
            compressed_item_count=1,
        )
        assert created.hash == "abc123"
        assert created.ttl == 300  # Default TTL
        assert created.retrieval_count == 0
        assert created.search_queries == []
        assert created.last_accessed is None

    def test_entry_is_expired_false_when_fresh(self):
        """An entry created just now with a 300 s TTL is not expired."""
        fresh = self._make_entry(ttl=300)
        assert fresh.is_expired() is False

    def test_entry_is_expired_true_after_ttl(self):
        """An entry older than its TTL reports itself as expired."""
        stale = self._make_entry(
            created_at=time.time() - 10,  # 10 seconds ago
            ttl=5,  # 5 second TTL
        )
        assert stale.is_expired() is True

    def test_record_access_increments_count(self):
        """Each record_access call bumps retrieval_count by one."""
        tracked = self._make_entry()
        assert tracked.retrieval_count == 0
        for expected in (1, 2):
            tracked.record_access()
            assert tracked.retrieval_count == expected

    def test_record_access_updates_last_accessed(self):
        """record_access stamps last_accessed with the current time."""
        tracked = self._make_entry()
        assert tracked.last_accessed is None
        lower_bound = time.time()
        tracked.record_access()
        upper_bound = time.time()
        assert tracked.last_accessed is not None
        assert lower_bound <= tracked.last_accessed <= upper_bound

    def test_record_access_tracks_unique_queries(self):
        """Repeated queries are recorded only once."""
        tracked = self._make_entry()
        for text in ("query1", "query2", "query1"):  # last one is a duplicate
            tracked.record_access(query=text)
        assert "query1" in tracked.search_queries
        assert "query2" in tracked.search_queries
        assert len(tracked.search_queries) == 2  # No duplicates

    def test_record_access_limits_queries_to_10(self):
        """Only the 10 most recent queries are kept."""
        tracked = self._make_entry()
        for index in range(15):
            tracked.record_access(query=f"query_{index}")
        assert len(tracked.search_queries) == 10
        # Should have the last 10 queries
        assert "query_5" in tracked.search_queries
        assert "query_14" in tracked.search_queries
        assert "query_0" not in tracked.search_queries

    def test_record_access_ignores_none_query(self):
        """Accesses without a query leave search_queries empty."""
        tracked = self._make_entry()
        tracked.record_access(query=None)
        tracked.record_access()
        assert len(tracked.search_queries) == 0
| # ============================================================================= | |
| # CompressionStore Initialization Tests | |
| # ============================================================================= | |
class TestCompressionStoreInit:
    """Constructor behaviour of CompressionStore."""

    def test_default_initialization(self):
        """A store built without arguments exposes the expected defaults."""
        instance = CompressionStore()
        assert instance._max_entries == 1000
        assert instance._default_ttl == 300
        assert instance._enable_feedback is True
        assert instance._backend is not None

    def test_custom_max_entries(self):
        """The max_entries argument is kept on the instance."""
        instance = CompressionStore(max_entries=500)
        assert instance._max_entries == 500

    def test_custom_default_ttl(self):
        """The default_ttl argument is kept on the instance."""
        instance = CompressionStore(default_ttl=600)
        assert instance._default_ttl == 600

    def test_feedback_can_be_disabled(self):
        """Passing enable_feedback=False turns feedback tracking off."""
        instance = CompressionStore(enable_feedback=False)
        assert instance._enable_feedback is False

    def test_custom_backend(self):
        """A caller-supplied backend object is used as-is."""
        # Configure the stubbed backend calls directly in the constructor.
        fake_backend = MagicMock(
            **{"count.return_value": 0, "get.return_value": None}
        )
        instance = CompressionStore(backend=fake_backend)
        assert instance._backend is fake_backend
| # ============================================================================= | |
| # Store Operations Tests | |
| # ============================================================================= | |
class TestCompressionStoreOperations:
    """Behaviour of CompressionStore.store()."""

    def test_store_returns_24_char_hash(self, store: CompressionStore):
        """The returned key is 24 lowercase hex characters (96 bits)."""
        key = store.store(original="[1,2,3]", compressed="[1]")
        assert len(key) == 24
        assert set(key) <= set("0123456789abcdef")

    def test_store_hash_is_deterministic(self, store: CompressionStore):
        """Storing identical content twice yields the same key."""
        payload = '{"id": 1, "name": "test"}'
        first = store.store(original=payload, compressed="{}")
        second = store.store(original=payload, compressed="{}")
        assert first == second

    def test_store_hash_based_on_original_content(self, store: CompressionStore):
        """The key depends on the original content, not the compressed one."""
        source = '{"id": 1}'
        keys = [
            store.store(original=source, compressed=variant)
            for variant in ('{"id": 1}', "{}")
        ]
        assert keys[0] == keys[1]  # Same original = same hash

    def test_store_different_content_different_hash(self, store: CompressionStore):
        """Distinct originals map to distinct keys."""
        first = store.store(original='{"id": 1}', compressed="{}")
        second = store.store(original='{"id": 2}', compressed="{}")
        assert first != second

    def test_store_preserves_all_metadata(
        self, store: CompressionStore, sample_original: str, sample_compressed: str
    ):
        """Every field passed to store() is readable back from the entry."""
        key = store.store(
            original=sample_original,
            compressed=sample_compressed,
            original_tokens=1000,
            compressed_tokens=100,
            original_item_count=100,
            compressed_item_count=10,
            tool_name="search_api",
            tool_call_id="call_123",
            query_context="user query",
            tool_signature_hash="sig_hash_123",
            compression_strategy="top_k",
            ttl=600,
        )
        fetched = store.retrieve(key)
        assert fetched is not None
        assert fetched.original_content == sample_original
        assert fetched.compressed_content == sample_compressed
        assert fetched.original_tokens == 1000
        assert fetched.compressed_tokens == 100
        assert fetched.original_item_count == 100
        assert fetched.compressed_item_count == 10
        assert fetched.tool_name == "search_api"
        assert fetched.tool_call_id == "call_123"
        assert fetched.query_context == "user query"
        assert fetched.tool_signature_hash == "sig_hash_123"
        assert fetched.compression_strategy == "top_k"
        assert fetched.ttl == 600

    def test_store_uses_default_ttl(self, store: CompressionStore):
        """Without an explicit ttl the entry gets the store default."""
        key = store.store(original="[1]", compressed="[]")
        fetched = store.retrieve(key)
        assert fetched is not None
        assert fetched.ttl == 300  # Default TTL

    def test_store_accepts_custom_ttl(self, store: CompressionStore):
        """An explicit ttl overrides the store default."""
        key = store.store(original="[1]", compressed="[]", ttl=60)
        fetched = store.retrieve(key)
        assert fetched is not None
        assert fetched.ttl == 60
| # ============================================================================= | |
| # Retrieve Operations Tests | |
| # ============================================================================= | |
class TestCompressionStoreRetrieve:
    """Behaviour of CompressionStore.retrieve()."""

    def test_retrieve_existing_entry(self, store: CompressionStore):
        """A stored key resolves to its entry."""
        key = store.store(original='{"id": 1}', compressed="{}")
        fetched = store.retrieve(key)
        assert fetched is not None
        assert fetched.hash == key
        assert fetched.original_content == '{"id": 1}'

    def test_retrieve_nonexistent_returns_none(self, store: CompressionStore):
        """An unknown key resolves to None."""
        assert store.retrieve("nonexistent_hash_key") is None

    def test_retrieve_expired_entry_returns_none(self, store_with_short_ttl: CompressionStore):
        """An entry past its TTL can no longer be retrieved."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        # Should exist immediately
        assert short_lived.retrieve(key) is not None
        # Wait for expiration
        time.sleep(1.1)
        # Should be None after expiration
        assert short_lived.retrieve(key) is None

    def test_retrieve_increments_access_count(self, store: CompressionStore):
        """Each retrieve contributes to the entry's retrieval_count."""
        key = store.store(original="[1]", compressed="[]")
        store.retrieve(key)
        store.retrieve(key)
        fetched = store.retrieve(key)
        assert fetched is not None
        assert fetched.retrieval_count >= 3

    def test_retrieve_with_query_tracks_query(self, store: CompressionStore):
        """A query passed to retrieve() ends up in search_queries."""
        key = store.store(original="[1]", compressed="[]")
        store.retrieve(key, query="test query")
        fetched = store.retrieve(key)
        assert fetched is not None
        assert "test query" in fetched.search_queries

    def test_retrieve_returns_copy_not_reference(self, store: CompressionStore):
        """Mutating one retrieved entry does not leak into another."""
        key = store.store(original="[1]", compressed="[]")
        first = store.retrieve(key)
        second = store.retrieve(key)
        assert first is not None
        assert second is not None
        # Modify the returned entry's mutable field
        first.search_queries.append("modified")
        # Should not affect the other entry
        assert "modified" not in second.search_queries
| # ============================================================================= | |
| # TTL Expiration Tests | |
| # ============================================================================= | |
class TestCompressionStoreTTL:
    """TTL expiration behaviour of CompressionStore (1 second TTL fixture)."""

    def test_entry_exists_before_ttl(self, store_with_short_ttl: CompressionStore):
        """A just-stored entry is reported as existing."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        assert short_lived.exists(key) is True

    def test_entry_not_exists_after_ttl(self, store_with_short_ttl: CompressionStore):
        """Once the TTL has elapsed the entry no longer exists."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        time.sleep(1.1)
        assert short_lived.exists(key) is False

    def test_get_metadata_returns_none_for_expired(self, store_with_short_ttl: CompressionStore):
        """get_metadata yields None for an entry past its TTL."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        time.sleep(1.1)
        assert short_lived.get_metadata(key) is None

    def test_search_returns_empty_for_expired(self, store_with_short_ttl: CompressionStore):
        """search yields an empty list for an entry past its TTL."""
        short_lived = store_with_short_ttl
        payload = json.dumps([{"id": 1, "name": "test"}])
        key = short_lived.store(original=payload, compressed="[]")
        time.sleep(1.1)
        assert short_lived.search(key, "test") == []

    def test_exists_clean_expired_false_does_not_delete(
        self, store_with_short_ttl: CompressionStore
    ):
        """exists(clean_expired=False) reports False for an expired entry."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        time.sleep(1.1)
        # Check exists without cleaning
        outcome = short_lived.exists(key, clean_expired=False)
        assert outcome is False
        # Entry should still be in backend (not cleaned yet)
        # This is internal behavior - the entry is there but marked expired
        # NOTE(review): the non-deletion itself is not asserted here.

    def test_exists_clean_expired_true_deletes(self, store_with_short_ttl: CompressionStore):
        """exists(clean_expired=True) reports False for an expired entry."""
        short_lived = store_with_short_ttl
        key = short_lived.store(original="[1]", compressed="[]")
        time.sleep(1.1)
        # Check exists with cleaning
        outcome = short_lived.exists(key, clean_expired=True)
        assert outcome is False
| # ============================================================================= | |
| # Eviction Tests | |
| # ============================================================================= | |
class TestCompressionStoreEviction:
    """Tests for CompressionStore memory limits and eviction."""

    def test_eviction_at_capacity(self, store_with_small_capacity: CompressionStore):
        """Oldest entries are evicted when more are stored than fit."""
        hashes = []
        for i in range(5):
            h = store_with_small_capacity.store(
                original=f"content_{i}",
                compressed=f"compressed_{i}",
            )
            hashes.append(h)
            time.sleep(0.01)  # Ensure different timestamps
        # Only last 3 should exist (capacity is 3)
        assert not store_with_small_capacity.exists(hashes[0])
        assert not store_with_small_capacity.exists(hashes[1])
        assert store_with_small_capacity.exists(hashes[2])
        assert store_with_small_capacity.exists(hashes[3])
        assert store_with_small_capacity.exists(hashes[4])

    def test_eviction_removes_oldest_first(self, store_with_small_capacity: CompressionStore):
        """Adding one entry beyond capacity evicts only the oldest entry."""
        # Fill to capacity
        hashes = []
        for i in range(3):
            h = store_with_small_capacity.store(
                original=f"content_{i}",
                compressed=f"compressed_{i}",
            )
            hashes.append(h)
            time.sleep(0.01)
        # All 3 should exist
        for h in hashes:
            assert store_with_small_capacity.exists(h)
        # Add one more - should evict oldest
        new_hash = store_with_small_capacity.store(
            original="content_new",
            compressed="compressed_new",
        )
        # Oldest should be evicted
        assert not store_with_small_capacity.exists(hashes[0])
        assert store_with_small_capacity.exists(hashes[1])
        assert store_with_small_capacity.exists(hashes[2])
        assert store_with_small_capacity.exists(new_hash)

    def test_eviction_cleans_expired_first(self):
        """Expired entries make room before any valid entry is evicted."""
        store = CompressionStore(max_entries=3, default_ttl=1)
        # Add 2 entries that will expire
        hash1 = store.store(original="content_1", compressed="c1", ttl=1)
        hash2 = store.store(original="content_2", compressed="c2", ttl=1)
        time.sleep(1.1)  # Wait for expiration
        # Add 2 more entries (should clean expired first, not evict new)
        hash3 = store.store(original="content_3", compressed="c3", ttl=300)
        hash4 = store.store(original="content_4", compressed="c4", ttl=300)
        # Expired entries should be gone
        assert not store.exists(hash1)
        assert not store.exists(hash2)
        # New entries should exist
        assert store.exists(hash3)
        assert store.exists(hash4)

    def test_heap_rebuild_on_stale_threshold(self):
        """Replacing every entry leaves the store consistent.

        Re-storing the same originals is meant to create stale heap entries.
        The heap itself is internal, so this test checks the observable
        outcome: no entry is lost or duplicated, and each key resolves to
        the most recently stored version.
        """
        store = CompressionStore(max_entries=10)
        # Store entries and then replace them to create stale heap entries
        original_hashes = [
            store.store(original=f"content_{i}", compressed=f"c_{i}")
            for i in range(5)
        ]
        # Replace all entries (creates stale heap entries)
        updated_hashes = [
            store.store(original=f"content_{i}", compressed=f"updated_{i}")
            for i in range(5)
        ]
        # Same originals must map to the same keys
        assert updated_hashes == original_hashes
        # Replacement must neither duplicate nor drop entries
        assert store.get_stats()["entry_count"] == 5
        # Every key must resolve to the replacement, not the stale version
        for i, hash_key in enumerate(updated_hashes):
            entry = store.retrieve(hash_key)
            assert entry is not None
            assert entry.compressed_content == f"updated_{i}"
| # ============================================================================= | |
| # Statistics Tests | |
| # ============================================================================= | |
class TestCompressionStoreStats:
    """Behaviour of CompressionStore.get_stats()."""

    def test_get_stats_entry_count(self, store: CompressionStore):
        """entry_count reflects the number of stored entries."""
        for payload in ("[1]", "[2]"):
            store.store(original=payload, compressed="[]")
        assert store.get_stats()["entry_count"] == 2

    def test_get_stats_max_entries(self, store: CompressionStore):
        """max_entries is reported in the stats."""
        assert store.get_stats()["max_entries"] == 1000

    def test_get_stats_token_totals(self, store: CompressionStore):
        """Token totals are summed over all stored entries."""
        store.store(
            original="[1]",
            compressed="[]",
            original_tokens=100,
            compressed_tokens=10,
        )
        store.store(
            original="[2]",
            compressed="[]",
            original_tokens=200,
            compressed_tokens=20,
        )
        summary = store.get_stats()
        assert summary["total_original_tokens"] == 300
        assert summary["total_compressed_tokens"] == 30

    def test_get_stats_retrieval_count(self, store: CompressionStore):
        """total_retrievals counts at least the retrievals performed."""
        key = store.store(original="[1]", compressed="[]")
        store.retrieve(key)
        store.retrieve(key)
        assert store.get_stats()["total_retrievals"] >= 2

    def test_get_stats_event_count(self, store: CompressionStore):
        """event_count counts at least the retrieval events logged."""
        key = store.store(original="[1]", compressed="[]")
        store.retrieve(key)
        store.retrieve(key)
        assert store.get_stats()["event_count"] >= 2

    def test_get_stats_includes_backend_stats(self, store: CompressionStore):
        """Stats carry a backend section identifying the memory backend."""
        store.store(original="[1]", compressed="[]")
        summary = store.get_stats()
        assert "backend" in summary
        assert summary["backend"]["backend_type"] == "memory"
| # ============================================================================= | |
| # Get Metadata Tests | |
| # ============================================================================= | |
class TestCompressionStoreMetadata:
    """Behaviour of CompressionStore.get_metadata()."""

    def test_get_metadata_returns_dict(self, store: CompressionStore):
        """Metadata for a stored entry exposes the expected keys and values."""
        key = store.store(
            original="[1,2,3]",
            compressed="[1]",
            tool_name="test_tool",
            original_item_count=3,
            compressed_item_count=1,
            query_context="test query",
        )
        info = store.get_metadata(key)
        assert info is not None
        assert info["hash"] == key
        assert info["tool_name"] == "test_tool"
        assert info["original_item_count"] == 3
        assert info["compressed_item_count"] == 1
        assert info["query_context"] == "test query"
        assert info["compressed_content"] == "[1]"
        assert "created_at" in info
        assert "ttl" in info

    def test_get_metadata_nonexistent_returns_none(self, store: CompressionStore):
        """An unknown key has no metadata."""
        assert store.get_metadata("nonexistent") is None
| # ============================================================================= | |
| # Search Tests | |
| # ============================================================================= | |
class TestCompressionStoreSearch:
    """Behaviour of CompressionStore.search()."""

    def test_search_with_bm25_returns_matches(self, store: CompressionStore):
        """A query returns at least the item that matches it best."""
        corpus = [
            {"id": 1, "content": "Python programming language"},
            {"id": 2, "content": "JavaScript web development"},
            {"id": 3, "content": "Python data science pandas"},
            {"id": 4, "content": "Java enterprise applications"},
            {"id": 5, "content": "Python machine learning tensorflow"},
        ]
        key = store.store(
            original=json.dumps(corpus),
            compressed=json.dumps(corpus[:2]),
        )
        hits = store.search(key, "Python programming")
        assert len(hits) >= 1
        hit_ids = [hit["id"] for hit in hits]
        assert 1 in hit_ids  # "Python programming language" should match

    def test_search_respects_max_results(self, store: CompressionStore):
        """No more than max_results items are returned."""
        corpus = [{"id": i, "content": f"item {i}"} for i in range(50)]
        key = store.store(original=json.dumps(corpus), compressed="[]")
        hits = store.search(key, "item", max_results=5)
        assert len(hits) <= 5

    def test_search_respects_score_threshold(self, store: CompressionStore):
        """With a score threshold, any results include the exact match."""
        corpus = [
            {"id": 1, "content": "exact match query term"},
            {"id": 2, "content": "completely unrelated content xyz"},
        ]
        key = store.store(original=json.dumps(corpus), compressed="[]")
        # High threshold should filter low-scoring items
        hits = store.search(key, "exact match query", score_threshold=0.5)
        # Should return the exact match, filter the unrelated
        # (an empty result is tolerated, as in the conditional form).
        assert not hits or any("exact match" in str(hit) for hit in hits)

    def test_search_nonexistent_returns_empty(self, store: CompressionStore):
        """Searching an unknown key yields an empty list."""
        assert store.search("nonexistent", "query") == []

    def test_search_invalid_json_returns_empty(self, store: CompressionStore):
        """Originals that are not valid JSON yield an empty list."""
        key = store.store(original="not valid json", compressed="[]")
        assert store.search(key, "query") == []

    def test_search_non_array_returns_empty(self, store: CompressionStore):
        """Originals whose JSON is not an array yield an empty list."""
        key = store.store(original=json.dumps({"key": "value"}), compressed="{}")
        assert store.search(key, "query") == []

    def test_search_empty_array_returns_empty(self, store: CompressionStore):
        """An empty JSON array yields an empty list."""
        key = store.store(original="[]", compressed="[]")
        assert store.search(key, "query") == []

    def test_search_logs_retrieval_event(self, store: CompressionStore):
        """A search is recorded as a retrieval event of type 'search'."""
        corpus = [{"id": 1, "content": "test"}]
        key = store.store(original=json.dumps(corpus), compressed="[]")
        store.search(key, "test query")
        logged = [
            event
            for event in store.get_retrieval_events()
            if event.retrieval_type == "search"
        ]
        assert len(logged) >= 1
        assert logged[-1].query == "test query"
| # ============================================================================= | |
| # Retrieval Events Tests | |
| # ============================================================================= | |
class TestCompressionStoreRetrievalEvents:
    """Retrieval event tracking in CompressionStore."""

    def test_retrieve_logs_full_event(self, store: CompressionStore):
        """retrieve() records an event of type 'full' carrying the tool name."""
        key = store.store(original="[1]", compressed="[]", tool_name="test_tool")
        store.retrieve(key)
        logged = [
            event
            for event in store.get_retrieval_events()
            if event.retrieval_type == "full"
        ]
        assert len(logged) >= 1
        assert logged[-1].tool_name == "test_tool"

    def test_get_retrieval_events_limit(self, store: CompressionStore):
        """No more than `limit` events are returned."""
        key = store.store(original="[1]", compressed="[]")
        for _attempt in range(10):
            store.retrieve(key)
        assert len(store.get_retrieval_events(limit=3)) <= 3

    def test_get_retrieval_events_filter_by_tool(self, store: CompressionStore):
        """Events can be filtered by the tool that produced the entry."""
        key_a = store.store(original="[1]", compressed="[]", tool_name="tool_a")
        key_b = store.store(original="[2]", compressed="[]", tool_name="tool_b")
        for key in (key_a, key_a, key_b):
            store.retrieve(key)
        events_a = store.get_retrieval_events(tool_name="tool_a")
        events_b = store.get_retrieval_events(tool_name="tool_b")
        assert len(events_a) == 2
        assert len(events_b) == 1

    def test_retrieval_events_include_tool_signature_hash(self, store: CompressionStore):
        """Events carry the entry's tool_signature_hash for TOIN correlation."""
        key = store.store(
            original="[1]",
            compressed="[]",
            tool_signature_hash="sig_123",
        )
        store.retrieve(key)
        logged = store.get_retrieval_events()
        assert len(logged) >= 1
        assert logged[-1].tool_signature_hash == "sig_123"
| # ============================================================================= | |
| # Edge Cases Tests | |
| # ============================================================================= | |
class TestCompressionStoreEdgeCases:
    """Tests for edge cases and error handling."""

    def test_store_empty_content(self, store: CompressionStore):
        """store() accepts empty strings and returns them unchanged."""
        hash_key = store.store(original="", compressed="")
        entry = store.retrieve(hash_key)
        assert entry is not None
        assert entry.original_content == ""

    def test_store_large_content(self, store: CompressionStore):
        """store() keeps the full length of a large payload."""
        large_content = json.dumps([{"id": i, "data": "x" * 1000} for i in range(100)])
        hash_key = store.store(original=large_content, compressed="[]")
        entry = store.retrieve(hash_key)
        assert entry is not None
        assert len(entry.original_content) == len(large_content)

    def test_store_unicode_content(self, store: CompressionStore):
        """store() round-trips non-ASCII content unchanged."""
        accented = "caf\u00e9"  # "cafe" with an acute accent on the e
        emoji = "\U0001f44b"  # waving hand, outside the Basic Multilingual Plane
        # ensure_ascii=False keeps the raw characters in the stored string;
        # the default would escape them and leave a pure-ASCII payload.
        unicode_content = json.dumps(
            [{"name": accented, "emoji": emoji}], ensure_ascii=False
        )
        hash_key = store.store(original=unicode_content, compressed="[]")
        entry = store.retrieve(hash_key)
        assert entry is not None
        assert entry.original_content == unicode_content
        assert accented in entry.original_content
        assert emoji in entry.original_content

    def test_duplicate_store_updates_entry(self, store: CompressionStore):
        """Storing the same original twice keeps the latest compressed value."""
        original = '{"id": 1}'
        hash1 = store.store(original=original, compressed="v1")
        hash2 = store.store(original=original, compressed="v2")
        assert hash1 == hash2
        entry = store.retrieve(hash1)
        assert entry is not None
        # Second store should have updated the entry
        assert entry.compressed_content == "v2"

    def test_clear_removes_all_entries(self, store: CompressionStore):
        """clear() leaves the store with an entry count of zero."""
        store.store(original="[1]", compressed="[]")
        store.store(original="[2]", compressed="[]")
        store.clear()
        stats = store.get_stats()
        assert stats["entry_count"] == 0

    def test_clear_removes_retrieval_events(self, store: CompressionStore):
        """clear() also discards recorded retrieval events."""
        hash_key = store.store(original="[1]", compressed="[]")
        store.retrieve(hash_key)
        store.clear()
        events = store.get_retrieval_events()
        assert len(events) == 0
| # ============================================================================= | |
| # Thread Safety Tests | |
| # ============================================================================= | |
class TestCompressionStoreThreadSafety:
    """Tests for thread safety."""

    @staticmethod
    def _run_all(threads: list[threading.Thread]) -> None:
        """Start every thread, then block until all of them have finished."""
        for t in threads:
            t.start()
        for t in threads:
            t.join()

    def test_concurrent_stores(self, store: CompressionStore):
        """Twenty threads storing distinct content all succeed and return a key."""
        hashes: list[str] = []
        errors: list[str] = []
        lock = threading.Lock()  # guards both result lists

        def store_item(i: int) -> None:
            try:
                h = store.store(
                    original=f"content_{i}",
                    compressed=f"compressed_{i}",
                )
                with lock:
                    hashes.append(h)
            except Exception as e:
                # Collected instead of raised: an exception inside a worker
                # thread would not fail the test by itself.
                with lock:
                    errors.append(str(e))

        self._run_all(
            [threading.Thread(target=store_item, args=(i,)) for i in range(20)]
        )

        assert errors == []
        assert len(hashes) == 20

    def test_concurrent_retrieves(self, store: CompressionStore):
        """Twenty threads retrieving the same key all get the original content."""
        hash_key = store.store(original="[1,2,3]", compressed="[1]")
        errors: list[str] = []
        results: list[CompressionEntry | None] = []
        lock = threading.Lock()  # guards both result lists

        def retrieve_item() -> None:
            try:
                entry = store.retrieve(hash_key)
                with lock:
                    results.append(entry)
            except Exception as e:
                with lock:
                    errors.append(str(e))

        self._run_all([threading.Thread(target=retrieve_item) for _ in range(20)])

        assert errors == []
        assert len(results) == 20
        for entry in results:
            assert entry is not None
            assert entry.original_content == "[1,2,3]"

    def test_concurrent_store_and_retrieve(self, store: CompressionStore):
        """Each thread can immediately retrieve exactly what it just stored."""
        errors: list[str] = []
        # The shared list is mutated from worker threads, so it is guarded the
        # same way as in the other tests of this class.
        lock = threading.Lock()

        def record(message: str) -> None:
            with lock:
                errors.append(message)

        def store_and_retrieve(i: int) -> None:
            try:
                items = [{"id": j, "batch": i} for j in range(10)]
                hash_key = store.store(
                    original=json.dumps(items),
                    compressed="[]",
                    tool_name=f"tool_{i}",
                )
                # Immediately retrieve
                entry = store.retrieve(hash_key)
                if entry is None:
                    record(f"Entry {i} not found after store")
                elif f'"batch": {i}' not in entry.original_content:
                    record(f"Entry {i} has wrong content")
            except Exception as e:
                record(str(e))

        self._run_all(
            [threading.Thread(target=store_and_retrieve, args=(i,)) for i in range(20)]
        )

        assert errors == [], f"Errors during concurrent operations: {errors}"
| # ============================================================================= | |
| # Global Store Singleton Tests | |
| # ============================================================================= | |
class TestGlobalStore:
    """Tests for global store singleton pattern."""

    def test_get_compression_store_returns_singleton(self):
        """Two consecutive get_compression_store() calls yield the same object."""
        first = get_compression_store()
        second = get_compression_store()
        assert first is second

    def test_reset_compression_store_clears_data(self):
        """After reset_compression_store(), the global store reports no entries."""
        get_compression_store().store(original="[1]", compressed="[]")

        reset_compression_store()

        entry_count = get_compression_store().get_stats()["entry_count"]
        assert entry_count == 0

    def test_get_compression_store_uses_params_only_on_first_call(self):
        """Constructor arguments take effect only when the singleton is created."""
        reset_compression_store()

        created = get_compression_store(max_entries=500, default_ttl=600)
        assert created._max_entries == 500
        assert created._default_ttl == 600

        # A later call with other arguments must hand back the existing
        # instance, still configured as it was on creation.
        again = get_compression_store(max_entries=100, default_ttl=60)
        assert again is created
        assert again._max_entries == 500  # Original value
| # ============================================================================= | |
| # Feedback Integration Tests | |
| # ============================================================================= | |
class TestCompressionStoreFeedback:
    """Tests for feedback loop integration."""

    def test_feedback_disabled_no_events(self):
        """With enable_feedback=False, get_retrieval_events() still returns a value."""
        store = CompressionStore(enable_feedback=False)
        hash_key = store.store(original="[1]", compressed="[]")
        store.retrieve(hash_key)
        # Events should still be tracked internally for the store
        # but process_pending_feedback won't forward them
        # Verify events are tracked even with feedback disabled
        assert store.get_retrieval_events() is not None

    def test_feedback_enabled_logs_events(self):
        """With enable_feedback=True, a retrieve() records at least one event."""
        store = CompressionStore(enable_feedback=True)
        hash_key = store.store(original="[1]", compressed="[]", tool_name="test")
        store.retrieve(hash_key)
        events = store.get_retrieval_events()
        assert len(events) >= 1

    # NOTE(review): mock_toin / mock_telemetry / mock_feedback are not fixtures
    # visible in this file; presumably they are injected by @patch decorators
    # that belong directly above this method - confirm they are present.
    def test_process_pending_feedback_forwards_events(
        self, mock_toin, mock_telemetry, mock_feedback
    ):
        """A retrieval with feedback enabled reaches the feedback object's record_retrieval."""
        mock_fb = MagicMock()
        mock_tel = MagicMock()
        mock_toin_instance = MagicMock()
        mock_feedback.return_value = mock_fb
        mock_telemetry.return_value = mock_tel
        mock_toin.return_value = mock_toin_instance
        store = CompressionStore(enable_feedback=True)
        hash_key = store.store(
            original="[1]",
            compressed="[]",
            tool_signature_hash="sig_123",
            compression_strategy="top_k",
        )
        store.retrieve(hash_key)
        # Feedback should have been called
        assert mock_fb.record_retrieval.called

    def test_eviction_success_creates_event(self):
        """Storing past max_entries evicts without error and respects the capacity."""
        store = CompressionStore(max_entries=2, enable_feedback=True)
        # Store entries with signature hash for eviction tracking
        store.store(
            original="content_0",
            compressed="c0",
            tool_signature_hash="sig_0",
            compression_strategy="top_k",
        )
        # Short pauses separate the stores in time.
        # NOTE(review): presumably eviction order depends on entry age - confirm.
        time.sleep(0.01)
        store.store(
            original="content_1",
            compressed="c1",
            tool_signature_hash="sig_1",
            compression_strategy="top_k",
        )
        time.sleep(0.01)
        # This should trigger eviction of first entry
        newest_key = store.store(
            original="content_2",
            compressed="c2",
            tool_signature_hash="sig_2",
            compression_strategy="top_k",
        )
        # The evicted entry (content_0) was never retrieved, so an
        # eviction_success event is expected to be queued.
        # NOTE(review): no accessor for the pending-feedback queue is visible
        # here, so only the observable outcome of the eviction is asserted.
        # Without these checks the test would pass vacuously.
        assert store.get_stats()["entry_count"] <= 2
        assert store.retrieve(newest_key) is not None
| # ============================================================================= | |
| # RetrievalEvent Tests | |
| # ============================================================================= | |
class TestRetrievalEvent:
    """Tests for RetrievalEvent dataclass."""

    def test_retrieval_event_creation(self):
        """Every field given to the constructor is readable as an attribute."""
        expected = {
            "hash": "abc123",
            "query": "test query",
            "items_retrieved": 5,
            "total_items": 100,
            "tool_name": "search_api",
            "retrieval_type": "search",
            "tool_signature_hash": "sig_123",
        }

        event = RetrievalEvent(timestamp=time.time(), **expected)

        # timestamp is supplied but deliberately not compared.
        for field_name, value in expected.items():
            assert getattr(event, field_name) == value

    def test_retrieval_event_default_signature_hash(self):
        """tool_signature_hash is None when the constructor omits it."""
        required_only = {
            "hash": "abc123",
            "query": None,
            "items_retrieved": 10,
            "total_items": 10,
            "tool_name": "test",
            "timestamp": time.time(),
            "retrieval_type": "full",
        }

        event = RetrievalEvent(**required_only)

        assert event.tool_signature_hash is None
| # ============================================================================= | |
| # Hash Collision Detection Tests | |
| # ============================================================================= | |
| class TestHashCollisionDetection: | |
| """Tests for hash collision detection and handling.""" | |
def test_same_content_no_collision_warning(
    self, store: CompressionStore, caplog: pytest.LogCaptureFixture
):
    """Re-storing an identical original must not log a hash-collision warning."""
    import logging

    repeated_original = "[1,2,3]"
    with caplog.at_level(logging.WARNING):
        # Same original both times; only the compressed form differs.
        for compressed_form in ("[1]", "[1,2]"):
            store.store(original=repeated_original, compressed=compressed_form)

    # Should not have collision warning
    captured = caplog.text
    assert "Hash collision detected" not in captured
def test_hash_uses_md5_truncated(self, store: CompressionStore):
    """The key from store() is the first 24 hex characters of the content's MD5."""
    content = "test content"
    # MD5 serves only as a fast, non-cryptographic content key here.
    full_digest = hashlib.md5(content.encode()).hexdigest()  # nosec B324

    returned_key = store.store(original=content, compressed="[]")

    assert returned_key == full_digest[:24]