"""
Two-tier cache: in-memory dict (fast path) backed by SQLite (persistence).

Design principles:
- Thread-safe: a single ``threading.Lock`` serializes every access to the
  in-memory dict and to the one SQLite connection shared by all threads
  (opened with ``check_same_thread=False`` so any thread may use it).
- TTL-based expiration. Callers may request stale data as a fallback when
  the upstream source is unreachable.
- Cache keys are SHA-256 hashes of ``(tool_name, sorted_params)`` so they
  are stable regardless of dict ordering.
"""

from __future__ import annotations

import hashlib
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

_DEFAULT_DB_PATH = Path(__file__).resolve().parent.parent / ".cache" / "cache.db"
_DEFAULT_TTL = 300.0  # 5 minutes


class TieredCache:
    """In-memory + SQLite two-tier cache with TTL expiration.

    Values must be JSON-serializable; they are persisted with ``json.dumps``.
    Usable as a context manager, in which case ``close()`` runs on exit.
    """

    def __init__(
        self,
        db_path: Path | str = _DEFAULT_DB_PATH,
        default_ttl: float = _DEFAULT_TTL,
    ) -> None:
        """Open (creating if needed) the SQLite store at *db_path*.

        Parameters
        ----------
        db_path:
            Location of the SQLite file. Its parent directory is created.
        default_ttl:
            Seconds an entry stays fresh when ``set`` is called without ``ttl``.
        """
        self.default_ttl = default_ttl
        self._mem: dict[str, tuple[float, Any]] = {}  # key -> (expires_at, value)
        # One lock guards both the dict and the shared SQLite connection.
        self._lock = threading.Lock()
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)

        conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key      TEXT PRIMARY KEY,
                    value    TEXT NOT NULL,
                    expires  REAL NOT NULL,
                    created  REAL NOT NULL
                )
                """
            )
            conn.commit()
        except BaseException:
            # Don't leak the connection if schema setup fails.
            conn.close()
            raise
        self._conn = conn

    def __enter__(self) -> TieredCache:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ------------------------------------------------------------------ #
    # Key generation
    # ------------------------------------------------------------------ #

    @staticmethod
    def make_key(tool_name: str, params: dict) -> str:
        """Deterministic cache key from tool name and parameters.

        ``sort_keys=True`` makes the key independent of dict insertion order
        (including nested dicts). Raises ``TypeError`` if *params* is not
        JSON-serializable.
        """
        raw = json.dumps({"tool": tool_name, "params": params}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def get(
        self, key: str, *, allow_stale: bool = False
    ) -> tuple[bool, Any | None]:
        """Retrieve a cached value.

        Returns
        -------
        (hit, value)
            ``hit`` is True when the value is present and not expired (or
            ``allow_stale`` is True and a stale value exists). On a miss the
            result is ``(False, None)``.

        Notes
        -----
        A fresh SQLite hit is promoted into the memory tier. Expired entries
        are never evicted here; ``prune_expired`` owns eviction.
        NOTE(review): a memory hit returns the object given to ``set``
        (not a copy), while a SQLite hit returns a JSON round-trip (e.g.
        tuples come back as lists).
        """
        now = time.time()
        with self._lock:
            # --- Memory tier ---
            entry = self._mem.get(key)
            if entry is not None:
                expires_at, value = entry
                if now < expires_at or allow_stale:
                    return True, value
                # Expired in memory — fall through to SQLite.

            # --- SQLite tier (shared connection, so still under the lock) ---
            row = self._conn.execute(
                "SELECT value, expires FROM cache WHERE key = ?", (key,)
            ).fetchone()
            if row is None:
                return False, None

            serialized, expires_at = row
            value = json.loads(serialized)
            if now < expires_at:
                # Promote to memory.
                self._mem[key] = (expires_at, value)
                return True, value
            if allow_stale:
                return True, value
            return False, None

    def set(self, key: str, value: Any, ttl: float | None = None) -> None:
        """Store *value* in both tiers, expiring *ttl* seconds from now.

        Raises ``TypeError`` (from ``json.dumps``) if *value* is not
        JSON-serializable. In that case neither tier is modified.
        """
        ttl = self.default_ttl if ttl is None else ttl
        now = time.time()
        expires_at = now + ttl

        # Serialize before touching either tier so a bad value can't leave
        # an entry behind in memory only.
        serialized = json.dumps(value)

        with self._lock:
            self._conn.execute(
                """
                INSERT INTO cache (key, value, expires, created)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(key) DO UPDATE SET
                    value=excluded.value,
                    expires=excluded.expires,
                    created=excluded.created
                """,
                (key, serialized, expires_at, now),
            )
            self._conn.commit()
            # Update memory only after the persistent write succeeded.
            self._mem[key] = (expires_at, value)

    def invalidate(self, key: str) -> None:
        """Remove an entry from both tiers."""
        with self._lock:
            self._mem.pop(key, None)
            self._conn.execute("DELETE FROM cache WHERE key = ?", (key,))
            self._conn.commit()

    def clear(self) -> None:
        """Wipe everything."""
        with self._lock:
            self._mem.clear()
            self._conn.execute("DELETE FROM cache")
            self._conn.commit()

    def prune_expired(self) -> int:
        """Delete all expired entries from both tiers.

        Returns the number of distinct keys removed. A key present in both
        tiers is counted once.
        """
        now = time.time()
        with self._lock:
            removed = {k for k, (exp, _) in self._mem.items() if now >= exp}
            for k in removed:
                del self._mem[k]

            rows = self._conn.execute(
                "SELECT key FROM cache WHERE expires <= ?", (now,)
            ).fetchall()
            removed.update(row[0] for row in rows)
            self._conn.execute("DELETE FROM cache WHERE expires <= ?", (now,))
            self._conn.commit()
        return len(removed)

    def close(self) -> None:
        """Close the SQLite connection. The cache is unusable afterwards."""
        with self._lock:
            self._conn.close()