Spaces:
Running
Running
| """Tests for CompressionStore storage backends. | |
| These tests define the contract that all backends must fulfill. | |
| Each backend implementation should pass all these tests. | |
| """ | |
| from __future__ import annotations | |
| import threading | |
| import time | |
| from typing import TYPE_CHECKING | |
| import pytest | |
| from headroom.cache.backends import CompressionStoreBackend, InMemoryBackend | |
| from headroom.cache.compression_store import CompressionEntry | |
| if TYPE_CHECKING: | |
| from collections.abc import Callable | |
def make_entry(
    hash_key: str = "test_hash",
    original: str = "original content",
    compressed: str = "compressed",
    original_tokens: int = 100,
    compressed_tokens: int = 10,
) -> CompressionEntry:
    """Build a CompressionEntry for use in tests.

    The five arguments are forwarded to the entry (``hash_key`` becomes
    ``hash``; ``original`` and ``compressed`` become the two content fields).
    All remaining fields are fixed constants, except ``created_at``, which is
    taken from ``time.time()`` at call time.
    """
    now = time.time()
    return CompressionEntry(
        # Caller-controlled fields.
        hash=hash_key,
        original_content=original,
        compressed_content=compressed,
        original_tokens=original_tokens,
        compressed_tokens=compressed_tokens,
        # Fixed metadata shared by every test entry.
        original_item_count=5,
        compressed_item_count=2,
        tool_name="test_tool",
        tool_call_id="call_123",
        query_context="test query",
        created_at=now,
        ttl=300,
    )
class TestCompressionStoreBackendProtocol:
    """Checks InMemoryBackend against the CompressionStoreBackend protocol."""

    def test_inmemory_backend_implements_protocol(self) -> None:
        """An InMemoryBackend instance must pass the isinstance protocol check."""
        assert isinstance(InMemoryBackend(), CompressionStoreBackend)

    def test_protocol_is_runtime_checkable(self) -> None:
        """An object with no backend methods must fail the isinstance check."""

        class NotABackend:
            pass

        impostor = NotABackend()
        assert not isinstance(impostor, CompressionStoreBackend)
class TestInMemoryBackend:
    """Test suite for InMemoryBackend.

    These tests define the contract for all backends.
    """

    # Registered as a fixture so pytest can inject it into every test that
    # declares a ``backend`` parameter; without the decorator those tests
    # error out with "fixture 'backend' not found".
    @pytest.fixture
    def backend(self) -> InMemoryBackend:
        """Provide a newly constructed InMemoryBackend to each test."""
        return InMemoryBackend()

    # --- Basic CRUD operations ---

    def test_get_returns_none_for_missing_key(self, backend: InMemoryBackend) -> None:
        """get() should return None for keys that don't exist."""
        assert backend.get("nonexistent") is None

    def test_set_and_get_roundtrip(self, backend: InMemoryBackend) -> None:
        """set() followed by get() should return the same entry."""
        entry = make_entry(hash_key="abc123")
        backend.set("abc123", entry)

        retrieved = backend.get("abc123")

        assert retrieved is not None
        assert retrieved.hash == "abc123"
        assert retrieved.original_content == "original content"
        assert retrieved.compressed_content == "compressed"
        assert retrieved.original_tokens == 100
        assert retrieved.compressed_tokens == 10

    def test_set_overwrites_existing(self, backend: InMemoryBackend) -> None:
        """set() should overwrite existing entries with the same key."""
        entry1 = make_entry(hash_key="abc123", original="first")
        entry2 = make_entry(hash_key="abc123", original="second")

        backend.set("abc123", entry1)
        backend.set("abc123", entry2)

        retrieved = backend.get("abc123")
        assert retrieved is not None
        assert retrieved.original_content == "second"

    def test_delete_removes_entry(self, backend: InMemoryBackend) -> None:
        """delete() should remove the entry and return True."""
        entry = make_entry(hash_key="abc123")
        backend.set("abc123", entry)

        result = backend.delete("abc123")

        assert result is True
        assert backend.get("abc123") is None

    def test_delete_returns_false_for_missing(self, backend: InMemoryBackend) -> None:
        """delete() should return False for keys that don't exist."""
        result = backend.delete("nonexistent")
        assert result is False

    def test_exists_returns_true_for_stored_entry(self, backend: InMemoryBackend) -> None:
        """exists() should return True for stored entries."""
        entry = make_entry(hash_key="abc123")
        backend.set("abc123", entry)
        assert backend.exists("abc123") is True

    def test_exists_returns_false_for_missing(self, backend: InMemoryBackend) -> None:
        """exists() should return False for missing entries."""
        assert backend.exists("nonexistent") is False

    def test_clear_removes_all_entries(self, backend: InMemoryBackend) -> None:
        """clear() should remove all entries."""
        stored_keys = ("key1", "key2", "key3")
        for key in stored_keys:
            backend.set(key, make_entry(hash_key=key))

        backend.clear()

        assert backend.count() == 0
        for key in stored_keys:
            assert backend.get(key) is None

    # --- Enumeration methods ---

    def test_count_returns_zero_for_empty(self, backend: InMemoryBackend) -> None:
        """count() should return 0 for empty backend."""
        assert backend.count() == 0

    def test_count_returns_correct_count(self, backend: InMemoryBackend) -> None:
        """count() should return the number of entries."""
        for key in ("key1", "key2", "key3"):
            backend.set(key, make_entry(hash_key=key))
        assert backend.count() == 3

    def test_keys_returns_empty_list_for_empty(self, backend: InMemoryBackend) -> None:
        """keys() should return empty list for empty backend."""
        assert backend.keys() == []

    def test_keys_returns_all_keys(self, backend: InMemoryBackend) -> None:
        """keys() should return all stored keys."""
        for key in ("key1", "key2", "key3"):
            backend.set(key, make_entry(hash_key=key))

        keys = backend.keys()

        # Compared as a set: the test does not require any particular order.
        assert set(keys) == {"key1", "key2", "key3"}

    def test_items_returns_empty_list_for_empty(self, backend: InMemoryBackend) -> None:
        """items() should return empty list for empty backend."""
        assert backend.items() == []

    def test_items_returns_all_entries(self, backend: InMemoryBackend) -> None:
        """items() should return all (key, entry) pairs."""
        entry1 = make_entry(hash_key="key1", original="content1")
        entry2 = make_entry(hash_key="key2", original="content2")
        backend.set("key1", entry1)
        backend.set("key2", entry2)

        items = backend.items()

        assert len(items) == 2
        items_dict = dict(items)
        assert items_dict["key1"].original_content == "content1"
        assert items_dict["key2"].original_content == "content2"

    # --- Statistics ---

    def test_get_stats_returns_required_fields(self, backend: InMemoryBackend) -> None:
        """get_stats() should return required fields."""
        stats = backend.get_stats()

        assert "backend_type" in stats
        assert "entry_count" in stats
        assert stats["backend_type"] == "memory"
        assert stats["entry_count"] == 0

    def test_get_stats_entry_count_accurate(self, backend: InMemoryBackend) -> None:
        """get_stats() entry_count should match actual count."""
        backend.set("key1", make_entry(hash_key="key1"))
        backend.set("key2", make_entry(hash_key="key2"))

        stats = backend.get_stats()

        assert stats["entry_count"] == 2

    def test_get_stats_bytes_used_increases(self, backend: InMemoryBackend) -> None:
        """get_stats() bytes_used should increase with entries."""
        stats_empty = backend.get_stats()

        backend.set(
            "key1",
            make_entry(hash_key="key1", original="x" * 1000),
        )
        stats_one = backend.get_stats()

        backend.set(
            "key2",
            make_entry(hash_key="key2", original="y" * 1000),
        )
        stats_two = backend.get_stats()

        assert stats_one["bytes_used"] > stats_empty["bytes_used"]
        assert stats_two["bytes_used"] > stats_one["bytes_used"]

    # --- Thread safety ---

    def test_concurrent_set_operations(self, backend: InMemoryBackend) -> None:
        """Backend should handle concurrent set operations safely."""
        num_threads = 10
        entries_per_thread = 100
        # Exceptions raised inside workers are collected here, because an
        # exception in a thread does not fail the test on its own.
        errors: list[Exception] = []

        def worker(thread_id: int) -> None:
            try:
                for i in range(entries_per_thread):
                    # Keys are unique per thread and iteration, so every
                    # set() should add a distinct entry.
                    key = f"thread{thread_id}_entry{i}"
                    entry = make_entry(hash_key=key)
                    backend.set(key, entry)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Comparing against [] makes a failure report show the exceptions.
        assert errors == []
        assert backend.count() == num_threads * entries_per_thread

    def test_concurrent_get_set_delete(self, backend: InMemoryBackend) -> None:
        """Backend should handle mixed concurrent operations safely."""
        num_iterations = 100
        errors: list[Exception] = []

        # Pre-populate some entries
        for i in range(50):
            backend.set(f"key{i}", make_entry(hash_key=f"key{i}"))

        def setter() -> None:
            try:
                for i in range(num_iterations):
                    backend.set(f"new_key{i}", make_entry(hash_key=f"new_key{i}"))
            except Exception as e:
                errors.append(e)

        def getter() -> None:
            try:
                for i in range(num_iterations):
                    # Cycle over the 50 pre-populated keys.
                    backend.get(f"key{i % 50}")
            except Exception as e:
                errors.append(e)

        def deleter() -> None:
            try:
                for i in range(num_iterations):
                    backend.delete(f"key{i % 50}")
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=setter),
            threading.Thread(target=setter),
            threading.Thread(target=getter),
            threading.Thread(target=getter),
            threading.Thread(target=deleter),
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Only the absence of exceptions is checked; the final contents
        # depend on thread interleaving.
        assert errors == []

    # --- Edge cases ---

    def test_empty_string_key(self, backend: InMemoryBackend) -> None:
        """Backend should handle empty string as key."""
        entry = make_entry(hash_key="")
        backend.set("", entry)

        retrieved = backend.get("")

        assert retrieved is not None
        assert retrieved.hash == ""

    def test_unicode_content(self, backend: InMemoryBackend) -> None:
        """Backend should handle unicode content correctly."""
        entry = make_entry(
            hash_key="unicode",
            original="日本語テスト 🎉 émojis",
            compressed="日本語",
        )
        backend.set("unicode", entry)

        retrieved = backend.get("unicode")

        assert retrieved is not None
        assert retrieved.original_content == "日本語テスト 🎉 émojis"
        assert retrieved.compressed_content == "日本語"

    def test_large_content(self, backend: InMemoryBackend) -> None:
        """Backend should handle large content."""
        large_content = "x" * 10_000_000  # 10 million characters
        entry = make_entry(hash_key="large", original=large_content)
        backend.set("large", entry)

        retrieved = backend.get("large")

        assert retrieved is not None
        assert len(retrieved.original_content) == 10_000_000
| # --- Parameterized tests for all backend implementations --- | |
def all_backends() -> list[Callable[[], CompressionStoreBackend]]:
    """List one zero-argument factory per backend implementation under test."""
    factories: list[Callable[[], CompressionStoreBackend]] = [InMemoryBackend]
    # Add more backends here as they're implemented:
    # MongoDBBackend,
    # RedisBackend,
    return factories
# Class-level parametrization supplies ``backend_factory`` to every test
# method, running each once per factory returned by all_backends().
@pytest.mark.parametrize("backend_factory", all_backends())
class TestBackendContract:
    """Contract tests that ALL backends must pass.

    These tests are parameterized to run against every backend implementation.
    Add new backends to all_backends() to include them in these tests.
    """

    def test_implements_protocol(
        self, backend_factory: Callable[[], CompressionStoreBackend]
    ) -> None:
        """All backends must implement CompressionStoreBackend protocol."""
        backend = backend_factory()
        assert isinstance(backend, CompressionStoreBackend)

    def test_basic_crud_cycle(self, backend_factory: Callable[[], CompressionStoreBackend]) -> None:
        """All backends must support basic CRUD operations."""
        backend = backend_factory()

        # Create
        entry = make_entry(hash_key="test")
        backend.set("test", entry)
        assert backend.exists("test")

        # Read
        retrieved = backend.get("test")
        assert retrieved is not None
        assert retrieved.original_content == entry.original_content

        # Update (overwrite)
        entry2 = make_entry(hash_key="test", original="updated")
        backend.set("test", entry2)
        retrieved2 = backend.get("test")
        assert retrieved2 is not None
        assert retrieved2.original_content == "updated"

        # Delete
        assert backend.delete("test") is True
        assert backend.exists("test") is False
        assert backend.get("test") is None

    def test_clear_works(self, backend_factory: Callable[[], CompressionStoreBackend]) -> None:
        """All backends must support clear()."""
        backend = backend_factory()
        backend.set("key1", make_entry(hash_key="key1"))
        backend.set("key2", make_entry(hash_key="key2"))
        assert backend.count() == 2

        backend.clear()

        assert backend.count() == 0

    def test_stats_has_required_fields(
        self, backend_factory: Callable[[], CompressionStoreBackend]
    ) -> None:
        """All backends must return required stats fields."""
        backend = backend_factory()

        stats = backend.get_stats()

        # Only these two keys and their types are required here; individual
        # backends may report additional statistics.
        assert "backend_type" in stats
        assert "entry_count" in stats
        assert isinstance(stats["backend_type"], str)
        assert isinstance(stats["entry_count"], int)