Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| import unittest | |
| from backend.cache_utils import PersistentCache, TTLCache | |
class _FakeLogger:
    """Logger stand-in that accepts ``error``/``info`` calls and discards them."""

    def _discard(self, msg: str, *args: object, **kwargs: object) -> None:
        """Swallow a log call of any shape; always returns ``None``."""
        return None

    # Both log levels share the same do-nothing implementation.
    error = _discard
    info = _discard
class CacheUtilsTests(unittest.TestCase):
    """Tests for the in-memory ``TTLCache`` and the file-backed ``PersistentCache``."""

    # -- helpers -----------------------------------------------------------

    def _make_temp_db_path(self) -> str:
        """Create an empty temporary ``.db`` file and schedule its removal.

        Removal is registered with ``addCleanup`` so it runs even when the test
        fails. Journal sidecar files (``-wal``/``-shm``/``-journal``) are also
        removed if present — presumably the backend is SQLite, which can leave
        them next to the database depending on its journal mode.
        """
        fd, temp_path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        self.addCleanup(self._remove_quietly, temp_path)
        for sidecar_suffix in ("-wal", "-shm", "-journal"):
            self.addCleanup(self._remove_quietly, temp_path + sidecar_suffix)
        return temp_path

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Delete *path* best-effort.

        A missing file is ignored, as is ``PermissionError`` (typically the file
        is still held open, e.g. on Windows); leaking a temp file is preferable
        to failing the test during cleanup.
        """
        try:
            os.remove(path)
        except (FileNotFoundError, PermissionError):
            pass

    @staticmethod
    def _make_persistent_cache(
        db_path: str, version: str = "test-version"
    ) -> PersistentCache:
        """Build a ``PersistentCache`` on *db_path* reporting *version* as its cache version."""
        return PersistentCache(
            db_path=db_path,
            cache_version_getter=lambda: version,
            logger=_FakeLogger(),
        )

    # -- TTLCache ----------------------------------------------------------

    def test_ttl_cache_returns_defensive_copy(self) -> None:
        """Mutating the stored payload or a returned value must not affect the cache."""
        cache = TTLCache()
        payload = {"forecast": [{"price": 100.0}], "meta": {"source": "memory"}}
        cache.set("sample", payload, ttl_seconds=30)

        # Mutate the caller's object after storing it: the cache must hold its own copy.
        payload["forecast"][0]["price"] = 999.0
        payload["meta"]["source"] = "mutated"

        cached_once = cache.get("sample")
        self.assertIsNotNone(cached_once)
        self.assertEqual(cached_once["forecast"][0]["price"], 100.0)
        self.assertEqual(cached_once["meta"]["source"], "memory")

        # Mutating a returned value must not leak back into later reads.
        cached_once["forecast"][0]["price"] = 555.0
        cached_twice = cache.get("sample")
        self.assertIsNotNone(cached_twice)
        self.assertEqual(cached_twice["forecast"][0]["price"], 100.0)

    # -- PersistentCache ---------------------------------------------------

    def test_persistent_cache_queue_copies_payload_before_enqueue(self) -> None:
        """``set`` must enqueue a copy, so later caller mutations are not persisted."""

        class FakeQueue:
            """Records the last item passed to ``put_nowait``."""

            def __init__(self) -> None:
                self.item = None

            def put_nowait(self, item: tuple[str, object, int]) -> None:
                self.item = item

        cache = self._make_persistent_cache(self._make_temp_db_path())
        queue = FakeQueue()
        # NOTE(review): relies on the private ``_queue`` attribute; presumably
        # ``set`` hands writes to it when present — confirm against PersistentCache.
        cache._queue = queue

        payload = {"nested": {"value": 1}}
        cache.set("queued", payload, ttl=60)
        payload["nested"]["value"] = 7

        self.assertIsNotNone(queue.item)
        _, queued_payload, queued_ttl = queue.item
        self.assertEqual(queued_ttl, 60)
        self.assertEqual(queued_payload["nested"]["value"], 1)

    def test_persistent_cache_persists_payload_across_instances(self) -> None:
        """A stored payload is isolated from caller mutation and readable by a new instance."""
        db_path = self._make_temp_db_path()
        cache = self._make_persistent_cache(db_path)

        payload = {"forecast": [{"price": 100.0}], "meta": {"source": "sqlite"}}
        cache.set("persisted", payload, ttl=60)
        payload["forecast"][0]["price"] = 999.0

        cached_same_instance = cache.get("persisted")
        self.assertIsNotNone(cached_same_instance)
        self.assertEqual(cached_same_instance["forecast"][0]["price"], 100.0)

        # A fresh instance on the same file must see the entry.
        reopened_cache = self._make_persistent_cache(db_path)
        cached_reopened = reopened_cache.get("persisted")
        self.assertIsNotNone(cached_reopened)
        self.assertEqual(cached_reopened["forecast"][0]["price"], 100.0)
        self.assertEqual(cached_reopened["meta"]["source"], "sqlite")

    def test_persistent_cache_treats_version_mismatch_as_cache_miss(self) -> None:
        """An entry written under another cache version reads as a miss."""
        db_path = self._make_temp_db_path()
        cache_v1 = self._make_persistent_cache(db_path, version="v1")
        cache_v1.set("versioned", {"value": 123}, ttl=60)

        cache_v2 = self._make_persistent_cache(db_path, version="v2")
        self.assertIsNone(cache_v2.get("versioned"))

        # Reopening with the original version still misses: this test expects the
        # mismatched read to have invalidated the stored entry, not merely skipped it.
        cache_v1_reopened = self._make_persistent_cache(db_path, version="v1")
        self.assertIsNone(cache_v1_reopened.get("versioned"))
if __name__ == "__main__":
    # Direct execution (``python <this file>``) runs the suite via unittest's CLI.
    unittest.main()