# SuperAI_Forecast / backend / cache_utils.py
# Thang6822
# Stabilize workspace UX and deployment flow
# 2eec8c3
from __future__ import annotations

import asyncio
import contextlib
import json
import os
import sqlite3
import time
from typing import Any, Callable, Dict, Optional, Protocol, Tuple

from backend.runtime_utils import clone_cache_payload
class LoggerLike(Protocol):
    """Structural type describing the logger methods this module calls.

    Any object that provides ``error`` and ``info`` with these signatures
    can be passed wherever a ``LoggerLike`` is expected.
    """

    def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record an error-level message; ``args`` fill printf-style slots."""

    def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record an info-level message; ``args`` fill printf-style slots."""
class PersistentCache:
    """SQLite-based persistent cache layer.

    Each entry is stored as JSON text together with an absolute expiry
    timestamp and the cache version that was current when it was written.
    Every operation opens its own short-lived connection and closes it
    before returning. Apart from the initial schema creation, storage
    errors are logged and swallowed so a cache failure never reaches the
    caller.
    """

    def __init__(
        self,
        db_path: str,
        cache_version_getter: Callable[[], str],
        logger: LoggerLike,
    ) -> None:
        """Create the database directory and schema if they are missing.

        Args:
            db_path: Location of the SQLite database file.
            cache_version_getter: Called on every read and write to obtain
                the current cache version; rows written under a different
                version are treated as stale.
            logger: Receives error and info messages.
        """
        self.db_path = db_path
        self._cache_version_getter = cache_version_getter
        self._logger = logger
        db_dir = os.path.dirname(os.path.abspath(self.db_path))
        os.makedirs(db_dir, exist_ok=True)
        self._init_db()

    def _init_db(self) -> None:
        """Create the ``cache`` table and its expiry index when absent."""
        # sqlite3's own context manager only commits or rolls back; wrapping
        # the connection in ``closing`` is what actually releases it.
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS cache (
                    key TEXT PRIMARY KEY,
                    payload BLOB,
                    expiry REAL,
                    version TEXT
                )
                """
            )
            conn.execute("CREATE INDEX IF NOT EXISTS idx_expiry ON cache(expiry)")

    def get(self, key: str) -> Optional[Any]:
        """Return the decoded payload for ``key``, or ``None`` on a miss.

        A row that has expired or was written under a different cache
        version is deleted and reported as a miss. Any error is logged and
        also reported as a miss.
        """
        try:
            cache_version = self._cache_version_getter()
            with contextlib.closing(sqlite3.connect(self.db_path)) as conn, conn:
                cur = conn.execute(
                    "SELECT payload, expiry, version FROM cache WHERE key = ?",
                    (key,),
                )
                row = cur.fetchone()
                if row:
                    payload, expiry, version = row
                    if time.time() < expiry and version == cache_version:
                        return json.loads(payload)
                    # Stale entry: remove it so it is not re-checked later.
                    conn.execute("DELETE FROM cache WHERE key = ?", (key,))
        except Exception as ex:
            self._logger.error("[Persistence] Read error: %s", ex)
        return None

    def set(self, key: str, payload: Any, ttl: int) -> None:
        """Store ``payload`` under ``key`` for ``ttl`` seconds.

        The payload is first passed through ``clone_cache_payload``. While
        the async writer is running the write is queued for it; otherwise
        it is performed synchronously.
        """
        cached_payload = clone_cache_payload(payload)
        if hasattr(self, "_queue"):
            self._queue.put_nowait((key, cached_payload, ttl))
            return
        self._write_sync(key, cached_payload, ttl)

    def _write_sync(self, key: str, payload: Any, ttl: int) -> None:
        """Insert or replace one row; failures are logged, never raised."""
        try:
            with contextlib.closing(sqlite3.connect(self.db_path)) as conn, conn:
                conn.execute(
                    "INSERT OR REPLACE INTO cache (key, payload, expiry, version) VALUES (?, ?, ?, ?)",
                    (key, json.dumps(payload), time.time() + ttl, self._cache_version_getter()),
                )
        except Exception as ex:
            self._logger.error("[Persistence] Sync write error: %s", ex)

    async def start_writer(self) -> None:
        """Run the background writer until the task is cancelled.

        While this coroutine runs, ``set`` enqueues writes and each one is
        executed in a worker thread. When the coroutine stops, the queue is
        detached so later ``set`` calls write synchronously again, and any
        entries still queued are flushed so they are not lost.
        """
        queue: asyncio.Queue[Tuple[str, Any, int]] = asyncio.Queue()
        self._queue = queue
        self._logger.info("[Persistence] Async writer started.")
        try:
            while True:
                key, payload, ttl = await queue.get()
                try:
                    await asyncio.to_thread(self._write_sync, key, payload, ttl)
                except Exception as ex:
                    self._logger.error("[Persistence] Writer loop error: %s", ex)
                finally:
                    queue.task_done()
        finally:
            # Detach only our own queue; a newer writer may have replaced it.
            if getattr(self, "_queue", None) is queue:
                del self._queue
            # Flush whatever was queued but not yet written.
            while True:
                try:
                    key, payload, ttl = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                self._write_sync(key, payload, ttl)
                queue.task_done()

    def evict(self) -> None:
        """Delete every row whose expiry timestamp lies in the past."""
        try:
            with contextlib.closing(sqlite3.connect(self.db_path)) as conn, conn:
                conn.execute("DELETE FROM cache WHERE expiry < ?", (time.time(),))
        except Exception as ex:
            self._logger.error("[Persistence] Eviction error: %s", ex)
class TTLCache:
    """In-memory key/value cache whose entries expire after a per-key TTL.

    Payloads are passed through ``clone_cache_payload`` both when stored
    and when read back.
    """

    def __init__(self) -> None:
        """Start with an empty store."""
        # key -> (absolute expiry timestamp, stored payload)
        self._store: Dict[str, Tuple[float, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        """Return a clone of the live payload for ``key``, else ``None``.

        An entry found to be expired is removed as a side effect.
        """
        record = self._store.get(key)
        if record is None:
            return None
        deadline, value = record
        expired = time.time() > deadline
        if expired:
            self._store.pop(key, None)
            return None
        return clone_cache_payload(value)

    def set(self, key: str, payload: Any, ttl_seconds: int) -> None:
        """Store a clone of ``payload`` that expires ``ttl_seconds`` from now."""
        deadline = time.time() + ttl_seconds
        self._store[key] = (deadline, clone_cache_payload(payload))

    def delete(self, key: str) -> bool:
        """Remove ``key``; report whether an entry was actually present."""
        removed = self._store.pop(key, None)
        return removed is not None

    def delete_by_prefix(self, prefix: str) -> int:
        """Remove every key starting with ``prefix``; return how many matched."""
        matched = 0
        # Iterate over a snapshot because the store is mutated in the loop.
        for candidate in list(self._store):
            if candidate.startswith(prefix):
                self._store.pop(candidate, None)
                matched += 1
        return matched

    def clear(self) -> int:
        """Drop all entries and return how many there were."""
        size_before = len(self._store)
        self._store.clear()
        return size_before

    def evict_expired(self) -> int:
        """Remove all expired entries and return how many were removed."""
        cutoff = time.time()
        stale = [name for name, record in self._store.items() if cutoff > record[0]]
        for name in stale:
            self._store.pop(name, None)
        return len(stale)

    def stats(self) -> Dict[str, Any]:
        """Return counts of total, still-alive and expired entries."""
        moment = time.time()
        total = len(self._store)
        alive = 0
        for deadline, _payload in self._store.values():
            if moment <= deadline:
                alive += 1
        return {
            "total_keys": total,
            "alive_keys": alive,
            "expired_keys": total - alive,
        }