"""
Two-tier cache: in-memory dict (fast path) backed by SQLite (persistence).
Design principles:
- Thread-safe via ``threading.Lock`` for the in-memory tier and synchronous
SQLite access (a single connection shared across threads, opened with
``check_same_thread=False``, with explicit locking).
- TTL-based expiration. Callers may request stale data as a fallback when
the upstream source is unreachable.
- Cache keys are SHA-256 hashes of ``tool_name`` and ``params`` serialized
as JSON with ``sort_keys=True``, so they are stable regardless of dict ordering.
"""
from __future__ import annotations
import hashlib
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)

# Default database file: <parent of this module's directory>/.cache/cache.db
_DEFAULT_DB_PATH = Path(__file__).resolve().parent.parent / ".cache" / "cache.db"
_DEFAULT_TTL = 300.0  # 5 minutes


class TieredCache:
    """In-memory + SQLite two-tier cache with TTL expiration.

    All access to the in-memory dict and to the single SQLite connection
    (shared across threads via ``check_same_thread=False``) is serialized by
    one ``threading.Lock``, so each public method sees and leaves both tiers
    in a consistent state.

    Values must be JSON-serializable because the SQLite tier stores
    ``json.dumps`` output. The memory tier keeps the caller's object itself,
    while values loaded back from SQLite are JSON round-tripped (for example,
    tuples come back as lists).

    Instances can be used as context managers; leaving the ``with`` block
    closes the SQLite connection.
    """

    def __init__(
        self,
        db_path: Path | str = _DEFAULT_DB_PATH,
        default_ttl: float = _DEFAULT_TTL,
    ) -> None:
        """Open (creating if needed) the SQLite cache database.

        Parameters
        ----------
        db_path:
            Location of the SQLite file; its parent directory is created if
            missing.
        default_ttl:
            Lifetime in seconds used by :meth:`set` when no ``ttl`` is given.
        """
        self.default_ttl = default_ttl
        self._mem: dict[str, tuple[float, Any]] = {}  # key -> (expires_at, value)
        self._lock = threading.Lock()
        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        # One connection shared by every thread; self._lock serializes its use.
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    expires REAL NOT NULL,
                    created REAL NOT NULL
                )
                """
            )
            self._conn.commit()
        except sqlite3.Error:
            # Don't leak the connection when schema setup fails.
            self._conn.close()
            raise

    # ------------------------------------------------------------------ #
    # Key generation
    # ------------------------------------------------------------------ #
    @staticmethod
    def make_key(tool_name: str, params: dict) -> str:
        """Deterministic cache key from tool name and parameters.

        ``params`` must be JSON-serializable; ``sort_keys=True`` makes the
        key independent of dict insertion order (including nested dicts).
        """
        raw = json.dumps({"tool": tool_name, "params": params}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    # ------------------------------------------------------------------ #
    # Context-manager support
    # ------------------------------------------------------------------ #
    def __enter__(self) -> TieredCache:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def get(
        self, key: str, *, allow_stale: bool = False
    ) -> tuple[bool, Any | None]:
        """Retrieve a cached value.

        A fresh in-memory entry is returned directly. Otherwise SQLite is
        consulted, and a fresh row there is promoted into memory. Only when
        neither tier holds a fresh copy does ``allow_stale`` apply: the
        expired in-memory value is preferred, then the expired SQLite value.

        Parameters
        ----------
        key:
            Cache key, typically produced by :meth:`make_key`.
        allow_stale:
            If True, return an expired value when no fresh one exists.

        Returns
        -------
        (hit, value)
            ``hit`` is True when the value is present and not expired (or
            ``allow_stale`` is True and a stale value exists). ``value`` is
            None on a miss.
        """
        now = time.time()
        with self._lock:
            # --- Memory tier ---
            entry = self._mem.get(key)
            if entry is not None and now < entry[0]:
                return True, entry[1]
            # Expired in memory: don't delete, let prune_expired handle it.

            # --- SQLite tier --- (may hold a fresher copy, e.g. one written
            # by another process sharing the database file)
            row = self._conn.execute(
                "SELECT value, expires FROM cache WHERE key = ?", (key,)
            ).fetchone()
            if row is not None and now < row[1]:
                value = json.loads(row[0])
                self._mem[key] = (row[1], value)  # promote to memory
                return True, value

            # --- Stale fallback ---
            if allow_stale:
                if entry is not None:
                    return True, entry[1]
                if row is not None:
                    return True, json.loads(row[0])
            return False, None

    def set(self, key: str, value: Any, ttl: float | None = None) -> None:
        """Store a value in both tiers.

        Parameters
        ----------
        key:
            Cache key.
        value:
            JSON-serializable value; ``TypeError`` is raised otherwise and
            neither tier is modified.
        ttl:
            Lifetime in seconds; ``None`` uses ``self.default_ttl``.
        """
        ttl = self.default_ttl if ttl is None else ttl
        now = time.time()
        expires_at = now + ttl
        # Serialize before touching either tier so a bad value cannot leave
        # a half-written entry behind.
        serialized = json.dumps(value)
        # A single critical section keeps memory and SQLite in agreement when
        # several threads set the same key concurrently.
        with self._lock:
            self._conn.execute(
                """
                INSERT INTO cache (key, value, expires, created)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(key) DO UPDATE SET value=excluded.value,
                                               expires=excluded.expires,
                                               created=excluded.created
                """,
                (key, serialized, expires_at, now),
            )
            self._conn.commit()
            self._mem[key] = (expires_at, value)

    def invalidate(self, key: str) -> None:
        """Remove an entry from both tiers (no-op if absent)."""
        with self._lock:
            self._mem.pop(key, None)
            self._conn.execute("DELETE FROM cache WHERE key = ?", (key,))
            self._conn.commit()

    def clear(self) -> None:
        """Wipe every entry from both tiers."""
        with self._lock:
            self._mem.clear()
            self._conn.execute("DELETE FROM cache")
            self._conn.commit()

    def prune_expired(self) -> int:
        """Delete all expired entries from both tiers.

        Returns
        -------
        int
            Number of memory entries removed plus number of SQLite rows
            removed; an entry present in both tiers is counted twice.
        """
        now = time.time()
        with self._lock:
            expired_keys = [k for k, (exp, _) in self._mem.items() if now >= exp]
            for k in expired_keys:
                del self._mem[k]
            cursor = self._conn.execute(
                "DELETE FROM cache WHERE expires <= ?", (now,)
            )
            self._conn.commit()
            return len(expired_keys) + cursor.rowcount

    def close(self) -> None:
        """Close the SQLite connection (waits for any in-flight operation)."""
        with self._lock:
            self._conn.close()