Spaces:
Sleeping
Sleeping
"""Two-tier cache: in-memory dict (fast path) backed by SQLite (persistence).

Design principles:
- Thread-safe: one SQLite connection, opened with
  ``check_same_thread=False``, is shared by all threads.  Access to it and
  to the in-memory dict is serialized with an explicit ``threading.Lock``.
- TTL-based expiration. Callers may request stale data as a fallback when
  the upstream source is unreachable.
- Cache keys are SHA-256 hashes of the JSON encoding of
  ``(tool_name, params)`` with sorted keys, so they are stable regardless of
  dict ordering.
"""
from __future__ import annotations

import hashlib
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

_DEFAULT_DB_PATH = Path(__file__).resolve().parent.parent / ".cache" / "cache.db"
_DEFAULT_TTL = 300.0  # 5 minutes


class TieredCache:
    """In-memory + SQLite two-tier cache with TTL expiration.

    Every value is kept in a process-local dict and mirrored as JSON in a
    SQLite table.  A single ``threading.Lock`` serializes all access to both
    tiers, including the one SQLite connection, which is opened with
    ``check_same_thread=False`` and therefore shared between threads.

    Values must be JSON-serializable.  A memory-tier hit returns the very
    object that was passed to :meth:`set`, whereas a SQLite-tier hit returns
    a fresh ``json.loads`` result (e.g. tuples come back as lists and
    non-string dict keys as strings).

    The cache is usable as a context manager; leaving the ``with`` block
    closes the SQLite connection.
    """

    def __init__(
        self,
        db_path: Path | str = _DEFAULT_DB_PATH,
        default_ttl: float = _DEFAULT_TTL,
    ) -> None:
        """Open (creating if necessary) the SQLite database at *db_path*.

        Parameters
        ----------
        db_path
            SQLite database file; its parent directory is created if missing.
        default_ttl
            Lifetime in seconds used by :meth:`set` when no ``ttl`` is given.
        """
        self.default_ttl = default_ttl
        self._mem: dict[str, tuple[float, Any]] = {}  # key -> (expires_at, value)
        self._lock = threading.Lock()
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    expires REAL NOT NULL,
                    created REAL NOT NULL
                )
                """
            )
            self._conn.commit()
        except Exception:
            # Don't leak the connection when schema setup fails.
            self._conn.close()
            raise

    # ------------------------------------------------------------------ #
    # Key generation
    # ------------------------------------------------------------------ #
    @staticmethod  # callable as TieredCache.make_key(...) or cache.make_key(...)
    def make_key(tool_name: str, params: dict) -> str:
        """Return a deterministic cache key for *tool_name* and *params*.

        The pair is JSON-encoded with sorted keys and hashed with SHA-256, so
        dicts that differ only in insertion order yield the same key.

        Raises
        ------
        TypeError
            If *params* contains values that are not JSON-serializable.
        """
        raw = json.dumps({"tool": tool_name, "params": params}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def get(
        self, key: str, *, allow_stale: bool = False
    ) -> tuple[bool, Any | None]:
        """Retrieve a cached value, checking memory first and then SQLite.

        Parameters
        ----------
        key
            Cache key, typically produced by :meth:`make_key`.
        allow_stale
            When True, an expired entry still counts as a hit (useful as a
            fallback when the upstream source is unreachable).

        Returns
        -------
        (hit, value)
            ``hit`` is True when the value is present and not expired (or
            ``allow_stale`` is True and a stale value exists); ``value`` is
            ``None`` on a miss.  A fresh SQLite hit is copied into memory.
        """
        now = time.time()
        # Hold the lock for the SQLite read too: the connection is shared
        # between threads, and a fresh hit writes back into ``self._mem``.
        with self._lock:
            # --- Memory tier ---
            entry = self._mem.get(key)
            if entry is not None:
                expires_at, value = entry
                if now < expires_at or allow_stale:
                    return True, value
                # Expired in memory: leave it for prune_expired() and consult
                # SQLite, whose row may have been refreshed by another
                # connection to the same database file.

            # --- SQLite tier ---
            row = self._conn.execute(
                "SELECT value, expires FROM cache WHERE key = ?", (key,)
            ).fetchone()
            if row is None:
                return False, None
            serialized, expires_at = row
            value = json.loads(serialized)
            if now < expires_at:
                # Promote to memory so the next lookup skips SQLite.
                self._mem[key] = (expires_at, value)
                return True, value
            if allow_stale:
                return True, value
            return False, None

    def set(self, key: str, value: Any, ttl: float | None = None) -> None:
        """Store *value* under *key* in both tiers.

        *ttl* is the lifetime in seconds; ``None`` means ``default_ttl``.

        Raises
        ------
        TypeError, ValueError
            If *value* cannot be JSON-encoded; neither tier is modified.
        """
        if ttl is None:
            ttl = self.default_ttl
        now = time.time()
        expires_at = now + ttl
        # Serialize before touching either tier so an unencodable value can't
        # leave the memory tier populated while SQLite is not.
        serialized = json.dumps(value)
        # One critical section for both tiers: with two separate acquisitions,
        # concurrent set() calls on the same key could leave memory holding
        # one writer's value and SQLite the other's.
        with self._lock:
            self._conn.execute(
                """
                INSERT INTO cache (key, value, expires, created)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(key) DO UPDATE SET value=excluded.value,
                                               expires=excluded.expires,
                                               created=excluded.created
                """,
                (key, serialized, expires_at, now),
            )
            self._conn.commit()
            # Memory is updated only after the SQLite write succeeded.
            self._mem[key] = (expires_at, value)

    def invalidate(self, key: str) -> None:
        """Remove *key* from both tiers (no-op if it is absent)."""
        with self._lock:
            self._mem.pop(key, None)
            self._conn.execute("DELETE FROM cache WHERE key = ?", (key,))
            self._conn.commit()

    def clear(self) -> None:
        """Remove every entry from both tiers."""
        with self._lock:
            self._mem.clear()
            self._conn.execute("DELETE FROM cache")
            self._conn.commit()

    def prune_expired(self) -> int:
        """Delete all expired entries from both tiers.

        Returns
        -------
        int
            Number of removals summed over both tiers; an entry that was
            present in memory and in SQLite counts twice.
        """
        now = time.time()
        with self._lock:
            # Same boundary as get(): an entry is expired once now >= expires.
            expired_keys = [k for k, (exp, _) in self._mem.items() if now >= exp]
            for k in expired_keys:
                del self._mem[k]
            cursor = self._conn.execute(
                "DELETE FROM cache WHERE expires <= ?", (now,)
            )
            self._conn.commit()
            return len(expired_keys) + cursor.rowcount

    def close(self) -> None:
        """Close the SQLite connection.

        Afterwards, any call that reaches SQLite raises
        ``sqlite3.ProgrammingError``.
        """
        with self._lock:
            self._conn.close()

    def __enter__(self) -> TieredCache:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()