WolfDavid's picture
Upload folder using huggingface_hub
777071b verified
"""
Two-tier cache: in-memory dict (fast path) backed by SQLite (persistence).
Design principles:
- Thread safety relies on a ``threading.Lock`` around in-memory and SQLite
operations. A single SQLite connection, opened with
``check_same_thread=False``, is shared by all threads (it is not one
connection per thread).
- TTL-based expiration. Callers may request stale data as a fallback when
the upstream source is unreachable.
- Cache keys are SHA-256 hashes of ``(tool_name, sorted_params)`` so they
are stable regardless of dict ordering.
"""
from __future__ import annotations
import hashlib
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# Default database file: ``.cache/cache.db`` under the parent of the directory
# that contains this module.
_DEFAULT_DB_PATH = Path(__file__).resolve().parent.parent / ".cache" / "cache.db"
_DEFAULT_TTL = 300.0  # 5 minutes


class TieredCache:
    """In-memory + SQLite two-tier cache with TTL expiration.

    :meth:`set` writes every entry to both tiers. :meth:`get` checks the
    in-memory dict first and falls back to SQLite, promoting fresh SQLite
    hits back into memory. One ``threading.Lock`` serializes all access to
    the dict and to the single shared SQLite connection, so each public
    operation is atomic with respect to the others on the same instance.

    Values must be JSON-serializable. The memory tier keeps the object passed
    to :meth:`set` as-is (the caller and the cache share it). A value read
    back from SQLite is its ``json.loads`` round-trip, so, for example,
    tuples come back as lists.
    """

    def __init__(
        self,
        db_path: Path | str = _DEFAULT_DB_PATH,
        default_ttl: float = _DEFAULT_TTL,
    ) -> None:
        """Open (creating if needed) the SQLite database and its cache table.

        Parameters
        ----------
        db_path
            Location of the SQLite database file. Its parent directory is
            created if missing.
        default_ttl
            Lifetime in seconds used by :meth:`set` when no ``ttl`` is given.
        """
        self.default_ttl = default_ttl
        self._mem: dict[str, tuple[float, Any]] = {}  # key -> (expires_at, value)
        self._lock = threading.Lock()
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        # A single connection is shared by every thread. check_same_thread=False
        # permits that, and self._lock serializes all use of it.
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    expires REAL NOT NULL,
                    created REAL NOT NULL
                )
                """
            )
            self._conn.commit()
        except BaseException:
            # Do not leak the connection if schema setup fails.
            self._conn.close()
            raise

    # ------------------------------------------------------------------ #
    # Key generation
    # ------------------------------------------------------------------ #
    @staticmethod
    def make_key(tool_name: str, params: dict) -> str:
        """Return a deterministic cache key for *tool_name* and *params*.

        The pair is JSON-encoded with ``sort_keys=True``, so dict ordering
        (at any nesting level) does not affect the key. The result is a
        64-character SHA-256 hex digest.

        Raises
        ------
        TypeError
            If *params* contains values that ``json.dumps`` cannot encode.
        """
        raw = json.dumps({"tool": tool_name, "params": params}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def get(
        self, key: str, *, allow_stale: bool = False
    ) -> tuple[bool, Any | None]:
        """Retrieve a cached value.

        Parameters
        ----------
        key
            Cache key, typically produced by :meth:`make_key`.
        allow_stale
            When True, an expired entry still present in either tier is
            returned as a hit. This serves as a fallback when the upstream
            source is unreachable.

        Returns
        -------
        (hit, value)
            ``hit`` is True when the value is present and not expired (or
            ``allow_stale`` is True and a stale value exists). On a miss
            the result is ``(False, None)``.
        """
        now = time.time()
        with self._lock:
            # --- Memory tier ---
            entry = self._mem.get(key)
            if entry is not None:
                expires_at, value = entry
                if now < expires_at or allow_stale:
                    return True, value
                # Expired in memory. Leave it for prune_expired and consult
                # SQLite, whose row may have been rewritten through another
                # connection to the same database file.

            # --- SQLite tier ---
            row = self._conn.execute(
                "SELECT value, expires FROM cache WHERE key = ?", (key,)
            ).fetchone()
            if row is None:
                return False, None
            serialized, expires_at = row
            value = json.loads(serialized)
            if now < expires_at:
                # Promote the fresh value so the next lookup hits memory.
                self._mem[key] = (expires_at, value)
                return True, value
            if allow_stale:
                return True, value
            return False, None

    def set(self, key: str, value: Any, ttl: float | None = None) -> None:
        """Store *value* under *key* in both tiers.

        Parameters
        ----------
        key
            Cache key.
        value
            JSON-serializable value to cache.
        ttl
            Lifetime in seconds; ``None`` means ``self.default_ttl``.

        Raises
        ------
        TypeError
            If *value* is not JSON-serializable. In that case neither tier
            is modified.
        """
        ttl = self.default_ttl if ttl is None else ttl
        now = time.time()
        expires_at = now + ttl
        # Serialize before touching either tier, so a bad value cannot leave
        # behind a memory-only entry.
        serialized = json.dumps(value)
        # Update both tiers under one lock acquisition so concurrent writers
        # cannot leave memory and SQLite holding different values.
        with self._lock:
            self._conn.execute(
                """
                INSERT INTO cache (key, value, expires, created)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(key) DO UPDATE SET value=excluded.value,
                                               expires=excluded.expires,
                                               created=excluded.created
                """,
                (key, serialized, expires_at, now),
            )
            self._conn.commit()
            # Only reached after the SQLite write succeeded.
            self._mem[key] = (expires_at, value)

    def invalidate(self, key: str) -> None:
        """Remove *key* from both tiers (a no-op if it is absent)."""
        with self._lock:
            self._mem.pop(key, None)
            self._conn.execute("DELETE FROM cache WHERE key = ?", (key,))
            self._conn.commit()

    def clear(self) -> None:
        """Remove every entry from both tiers."""
        with self._lock:
            self._mem.clear()
            self._conn.execute("DELETE FROM cache")
            self._conn.commit()

    def prune_expired(self) -> int:
        """Delete all expired entries from both tiers.

        Returns
        -------
        int
            Number of distinct keys removed. A key present in both tiers
            counts once.
        """
        now = time.time()
        with self._lock:
            mem_expired = {k for k, (exp, _) in self._mem.items() if now >= exp}
            for k in mem_expired:
                del self._mem[k]
            db_expired = {
                row[0]
                for row in self._conn.execute(
                    "SELECT key FROM cache WHERE expires <= ?", (now,)
                )
            }
            self._conn.execute("DELETE FROM cache WHERE expires <= ?", (now,))
            self._conn.commit()
        return len(mem_expired | db_expired)

    def close(self) -> None:
        """Close the SQLite connection.

        The lock is held so that no in-flight operation is using the
        connection when it closes. Afterwards, operations that reach SQLite
        raise ``sqlite3.ProgrammingError``.
        """
        with self._lock:
            self._conn.close()