"""
EpiRAG - cache.py
-----------------
Two-layer caching system:

1. Embedding cache - writes/loads per-paper embedding vectors to disk so
   ingest.py does not re-embed on every restart.
2. Result cache - SQLite-backed, keyed by normalised query hash, caches the
   full JSON result for identical queries.
"""

from __future__ import annotations

import os
import re
import json
import time
import hashlib
import sqlite3
import threading
from contextlib import closing

import numpy as np

# --- Config

EMBED_CACHE_DIR = "./embeddings"
RESULT_CACHE_DB = "./cache.db"
RESULT_TTL_SECS = 6 * 3600  # 6 hours; set to 0 to disable TTL

# Serialises result-cache database access between threads of this process
# (it does not coordinate with other processes using the same file).
_db_lock = threading.Lock()

# Characters not allowed in embedding-cache file names.
_UNSAFE_NAME_CHARS = re.compile(r"[^a-zA-Z0-9_\-]")


# --- Embedding cache


def embed_cache_path(paper_name: str, chunk_index: int) -> str:
    """Return the ``.npy`` file path that caches one chunk's embedding.

    Every character of *paper_name* outside ``[a-zA-Z0-9_-]`` is replaced
    with ``_`` so the name is safe to use as a file name.

    NOTE(review): names that differ only in replaced characters (e.g.
    "a b" and "a_b") map to the same file. The scheme is kept as-is because
    changing it would invalidate every existing cached embedding.
    """
    safe = _UNSAFE_NAME_CHARS.sub("_", paper_name)
    return os.path.join(EMBED_CACHE_DIR, f"{safe}_chunk_{chunk_index}.npy")


def load_embedding(paper_name: str, chunk_index: int) -> list[float] | None:
    """Return the cached embedding vector, or None if absent or unreadable.

    An unreadable (e.g. truncated) file is logged and reported as a miss.
    """
    path = embed_cache_path(paper_name, chunk_index)
    try:
        return np.load(path).tolist()
    except FileNotFoundError:
        return None
    except Exception as e:
        print(f" [EmbedCache] Failed to load {path}: {e}", flush=True)
        return None


def save_embedding(paper_name: str, chunk_index: int, vector: list[float]) -> None:
    """Persist an embedding vector to disk as float32 (best effort).

    Any failure, including failure to create EMBED_CACHE_DIR, is logged
    rather than raised. A missing file is simply reported as a miss by
    load_embedding.
    """
    path = embed_cache_path(paper_name, chunk_index)
    try:
        os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
        np.save(path, np.array(vector, dtype=np.float32))
    except Exception as e:
        print(f" [EmbedCache] Failed to save {path}: {e}", flush=True)


# --- Result cache (SQLite)


def _get_conn() -> sqlite3.Connection:
    """Open a connection to RESULT_CACHE_DB, creating the table if needed.

    The caller owns the returned connection and must close it.
    """
    conn = sqlite3.connect(RESULT_CACHE_DB, check_same_thread=False)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS result_cache (
                query_hash  TEXT PRIMARY KEY,
                query_text  TEXT,
                result_json TEXT,
                created_at  REAL
            )
        """)
        conn.commit()
    except Exception:
        # Don't leak the handle if table creation fails.
        conn.close()
        raise
    return conn


def _query_hash(question: str) -> str:
    """Normalise a query (lower-case, collapse whitespace); return its SHA-256 hex digest."""
    normalised = " ".join(question.lower().split())
    # "surrogatepass" lets lone surrogates (malformed input) be hashed instead
    # of raising; well-formed text encodes to exactly the same UTF-8 bytes.
    return hashlib.sha256(normalised.encode("utf-8", "surrogatepass")).hexdigest()


def get_cached_result(question: str) -> dict | None:
    """
    Return a previously cached result dict, or None if not found / expired.

    An expired row is deleted as a side effect. Any error is logged and
    treated as a cache miss.
    """
    try:
        key = _query_hash(question)
        with _db_lock, closing(_get_conn()) as conn:
            row = conn.execute(
                "SELECT result_json, created_at FROM result_cache WHERE query_hash = ?",
                (key,),
            ).fetchone()
            if row is None:
                return None
            result_json, created_at = row
            if RESULT_TTL_SECS > 0 and (time.time() - created_at) > RESULT_TTL_SECS:
                # Expired: delete only the exact row just read, so a newer row
                # written under the same key (e.g. by another process sharing
                # the database file) is not removed.
                conn.execute(
                    "DELETE FROM result_cache WHERE query_hash = ? AND created_at = ?",
                    (key, created_at),
                )
                conn.commit()
                return None
        return json.loads(result_json)
    except Exception as e:
        print(f" [ResultCache] Read error: {e}", flush=True)
        return None


def set_cached_result(question: str, result: dict) -> None:
    """
    Store a result dict in the SQLite cache (best effort).

    Results with mode == 'none' are not cached. The 'debate_rounds' entry is
    dropped to keep stored rows small. If the remaining values are not JSON
    serialisable, the write fails and is logged.
    """
    if result.get("mode") == "none":
        return
    try:
        key = _query_hash(question)
        serialisable = {k: v for k, v in result.items() if k != "debate_rounds"}
        # Serialise before taking the lock to keep the critical section short.
        payload = json.dumps(serialisable, ensure_ascii=False)
        with _db_lock, closing(_get_conn()) as conn:
            conn.execute(
                """INSERT OR REPLACE INTO result_cache
                   (query_hash, query_text, result_json, created_at)
                   VALUES (?, ?, ?, ?)""",
                (key, question, payload, time.time()),
            )
            conn.commit()
    except Exception as e:
        print(f" [ResultCache] Write error: {e}", flush=True)


def delete_cached_result(question: str) -> None:
    """Remove the cached result for *question*, if any (best effort; errors are logged)."""
    try:
        key = _query_hash(question)
        with _db_lock, closing(_get_conn()) as conn:
            conn.execute("DELETE FROM result_cache WHERE query_hash = ?", (key,))
            conn.commit()
    except Exception as e:
        print(f" [ResultCache] Delete error: {e}", flush=True)


def cache_stats() -> dict:
    """Return basic stats about the result cache.

    Returns a dict with 'cached_results' (row count) and 'oldest_entry' (the
    smallest created_at timestamp, or None when empty). On error the failure
    is logged and zero/None values are returned.
    """
    try:
        with _db_lock, closing(_get_conn()) as conn:
            count, oldest = conn.execute(
                "SELECT COUNT(*), MIN(created_at) FROM result_cache"
            ).fetchone()
        return {"cached_results": count, "oldest_entry": oldest}
    except Exception as e:
        print(f" [ResultCache] Stats error: {e}", flush=True)
        return {"cached_results": 0, "oldest_entry": None}