File size: 5,674 Bytes
777071b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Two-tier cache: in-memory dict (fast path) backed by SQLite (persistence).

Design principles:
 - Thread-safe via a single ``threading.Lock`` that guards both the in-memory
   tier and the SQLite tier (one connection shared across threads, opened
   with ``check_same_thread=False``).
 - TTL-based expiration.  Callers may request stale data as a fallback when
   the upstream source is unreachable.
 - Cache keys are SHA-256 hashes of ``(tool_name, sorted_params)`` so they
   are stable regardless of dict ordering.
"""

from __future__ import annotations

import hashlib
import json
import logging
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any

logger = logging.getLogger(__name__)

# Default database location: ``.cache/cache.db`` under the parent of the
# directory that contains this file.
_DEFAULT_DB_PATH = Path(__file__).resolve().parent.parent / ".cache" / "cache.db"
# Entry lifetime in seconds used when no explicit TTL is supplied.
_DEFAULT_TTL = 300.0  # 5 minutes


class TieredCache:
    """In-memory + SQLite two-tier cache with TTL expiration.

    Every entry is written to both a process-local dict (fast path) and a
    SQLite table (persistence).  A single ``threading.Lock`` serializes all
    access to the dict and to the shared SQLite connection.

    Values must be JSON-serializable.  The memory tier hands back the object
    that was stored, whereas a value read from SQLite has been round-tripped
    through JSON (so, for example, a tuple comes back as a list).

    Instances can be used as context managers; the SQLite connection is
    closed on exit.
    """

    def __init__(
        self,
        db_path: Path | str = _DEFAULT_DB_PATH,
        default_ttl: float = _DEFAULT_TTL,
    ) -> None:
        """Open (creating if necessary) the backing database.

        Parameters
        ----------
        db_path
            Location of the SQLite file.  Missing parent directories are
            created.
        default_ttl
            Lifetime in seconds applied by :meth:`set` when no ``ttl`` is
            given.

        Raises
        ------
        sqlite3.Error
            If the database cannot be opened or the schema cannot be created.
        """
        self.default_ttl = default_ttl
        self._mem: dict[str, tuple[float, Any]] = {}  # key -> (expires_at, value)
        self._lock = threading.Lock()

        self._db_path = Path(db_path)
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        # One connection shared by all threads; ``self._lock`` serializes use.
        self._conn = sqlite3.connect(str(self._db_path), check_same_thread=False)
        try:
            self._conn.execute("PRAGMA journal_mode=WAL")
            with self._conn:
                self._conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS cache (
                        key       TEXT PRIMARY KEY,
                        value     TEXT NOT NULL,
                        expires   REAL NOT NULL,
                        created   REAL NOT NULL
                    )
                    """
                )
        except sqlite3.Error:
            # Do not leak the connection when schema setup fails.
            self._conn.close()
            raise

    def __enter__(self) -> TieredCache:
        """Return the cache itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the SQLite connection when leaving a ``with`` block."""
        self.close()

    # ------------------------------------------------------------------ #
    # Key generation
    # ------------------------------------------------------------------ #
    @staticmethod
    def make_key(tool_name: str, params: dict) -> str:
        """Deterministic cache key from tool name and parameters.

        The key is the SHA-256 hex digest of the JSON encoding (with sorted
        keys) of ``{"tool": tool_name, "params": params}``, so it does not
        depend on dict insertion order.

        Raises
        ------
        TypeError
            If ``params`` is not JSON-serializable.
        """
        raw = json.dumps({"tool": tool_name, "params": params}, sort_keys=True)
        return hashlib.sha256(raw.encode()).hexdigest()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def get(
        self, key: str, *, allow_stale: bool = False
    ) -> tuple[bool, Any | None]:
        """Retrieve a cached value.

        Lookup order: a fresh memory entry, then a fresh SQLite row (which is
        promoted to memory), then -- only when ``allow_stale`` is True -- the
        expired entry with the latest expiry from either tier.

        Returns
        -------
        (hit, value)
            ``hit`` is True when the value is present and not expired (or
            ``allow_stale`` is True and a stale value exists).  On a miss the
            result is ``(False, None)``.
        """
        now = time.time()

        with self._lock:
            # --- Memory tier ---
            mem_entry = self._mem.get(key)
            if mem_entry is not None and now < mem_entry[0]:
                return True, mem_entry[1]
            # An expired memory entry is kept (prune_expired removes it) and
            # remembered as a stale candidate; SQLite is still consulted
            # because it may hold a fresher row written through another
            # connection to the same database file.

            # --- SQLite tier ---
            row = self._conn.execute(
                "SELECT value, expires FROM cache WHERE key = ?", (key,)
            ).fetchone()
            if row is not None:
                serialized, db_expires = row
                if now < db_expires:
                    value = json.loads(serialized)
                    # Promote to memory.
                    self._mem[key] = (db_expires, value)
                    return True, value
                if allow_stale and (mem_entry is None or db_expires > mem_entry[0]):
                    return True, json.loads(serialized)

            if allow_stale and mem_entry is not None:
                return True, mem_entry[1]

        return False, None

    def set(self, key: str, value: Any, ttl: float | None = None) -> None:
        """Store a value in both tiers.

        Parameters
        ----------
        key
            Cache key, typically produced by :meth:`make_key`.
        value
            JSON-serializable payload.
        ttl
            Lifetime in seconds; ``None`` means ``self.default_ttl``.

        Raises
        ------
        TypeError
            If ``value`` is not JSON-serializable.  Neither tier is modified
            in that case.
        """
        lifetime = self.default_ttl if ttl is None else ttl
        # Serialize before touching either tier so a failure leaves the cache
        # unchanged.
        serialized = json.dumps(value)
        now = time.time()
        expires_at = now + lifetime

        # One critical section for both tiers: a concurrent invalidate() can
        # no longer slip in between the memory write and the SQLite write.
        with self._lock:
            with self._conn:  # commits on success, rolls back on error
                self._conn.execute(
                    """
                    INSERT INTO cache (key, value, expires, created)
                    VALUES (?, ?, ?, ?)
                    ON CONFLICT(key) DO UPDATE SET value=excluded.value,
                                                   expires=excluded.expires,
                                                   created=excluded.created
                    """,
                    (key, serialized, expires_at, now),
                )
            self._mem[key] = (expires_at, value)

    def invalidate(self, key: str) -> None:
        """Remove an entry from both tiers.  Unknown keys are ignored."""
        with self._lock:
            self._mem.pop(key, None)
            with self._conn:
                self._conn.execute("DELETE FROM cache WHERE key = ?", (key,))

    def clear(self) -> None:
        """Wipe everything from both tiers."""
        with self._lock:
            self._mem.clear()
            with self._conn:
                self._conn.execute("DELETE FROM cache")

    def prune_expired(self) -> int:
        """Delete all expired entries from both tiers.

        Returns
        -------
        int
            Number of distinct keys removed.  A key that was present in both
            tiers is counted once.
        """
        now = time.time()

        with self._lock:
            expired_in_memory = [
                k for k, (expires_at, _) in self._mem.items() if now >= expires_at
            ]
            for k in expired_in_memory:
                del self._mem[k]

            with self._conn:
                expired_rows = self._conn.execute(
                    "SELECT key FROM cache WHERE expires <= ?", (now,)
                ).fetchall()
                self._conn.execute(
                    "DELETE FROM cache WHERE expires <= ?", (now,)
                )

        removed = set(expired_in_memory)
        removed.update(db_key for (db_key,) in expired_rows)
        return len(removed)

    def close(self) -> None:
        """Close the SQLite connection.

        The lock is held so the connection is not closed while another thread
        is in the middle of a cache operation.
        """
        with self._lock:
            self._conn.close()