epirag / cache.py
RohanB67's picture
feat: literature search mode, KaTeX math, parallel citations, Zeta rewrite, search quality fixes
d42aa58
Raw
History Blame Contribute Delete
4.91 kB
"""
EpiRAG - cache.py
-----------------
Two-layer caching system:
1. Embedding cache - saves/loads one embedding vector per paper chunk
   (as .npy files on disk) so ingest.py does not re-embed on every restart.
2. Result cache - SQLite-backed, keyed by the SHA-256 hash of the normalised
   query; caches the JSON result (minus debate transcripts) for identical
   queries, with entries expiring after RESULT_TTL_SECS.
"""
from __future__ import annotations
import os
import re
import json
import time
import hashlib
import sqlite3
import threading
import numpy as np
# --- Config
# Directory that holds one .npy file per cached chunk embedding.
EMBED_CACHE_DIR = "./embeddings"
# Path of the SQLite database file backing the result cache.
RESULT_CACHE_DB = "./cache.db"
# Maximum age (in seconds) of a cached result before it is treated as expired.
RESULT_TTL_SECS = 6 * 3600 # 6 hours; set to 0 to disable TTL
# Serialises the result-cache database accesses made by the functions below.
# NOTE: threading.Lock is not re-entrant, so code holding it must not call
# another function that acquires it.
_db_lock = threading.Lock()
# --- Embedding cache
def embed_cache_path(paper_name: str, chunk_index: int) -> str:
    """Build the on-disk path of the cached embedding for one paper chunk.

    Every character of ``paper_name`` other than ASCII letters, digits,
    ``_`` and ``-`` is replaced by ``_`` before it is used in the file name,
    which has the form ``<name>_chunk_<index>.npy`` inside ``EMBED_CACHE_DIR``.
    """
    sanitised_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", paper_name)
    file_name = f"{sanitised_name}_chunk_{chunk_index}.npy"
    return os.path.join(EMBED_CACHE_DIR, file_name)
def load_embedding(paper_name: str, chunk_index: int) -> list[float] | None:
    """Fetch a chunk's embedding from the disk cache.

    Returns the stored array converted with ``tolist()``, or ``None`` when
    no cache file exists for this chunk or the file cannot be loaded.
    """
    cache_file = embed_cache_path(paper_name, chunk_index)
    if not os.path.exists(cache_file):
        return None
    try:
        stored = np.load(cache_file)
        return stored.tolist()
    except Exception:
        # An unreadable cache file is treated exactly like a cache miss.
        return None
def save_embedding(paper_name: str, chunk_index: int, vector: list[float]):
    """Write one chunk's embedding vector to the disk cache as float32.

    The cache directory is created on demand. A failure while converting or
    writing the vector is reported on stdout and otherwise ignored.
    """
    os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
    target = embed_cache_path(paper_name, chunk_index)
    try:
        as_float32 = np.array(vector, dtype=np.float32)
        np.save(target, as_float32)
    except Exception as exc:
        print(f" [EmbedCache] Failed to save {target}: {exc}", flush=True)
# --- Result cache (SQLite)
def _get_conn() -> sqlite3.Connection:
    """Open the result-cache database and make sure its table exists.

    Returns:
        A new connection to ``RESULT_CACHE_DB``. The caller owns the
        connection and is responsible for closing it.

    Raises:
        Exception: Whatever ``sqlite3`` raises if the database cannot be
            opened or the ``result_cache`` table cannot be created. In the
            latter case the connection is closed before re-raising.
    """
    # check_same_thread=False permits use of the connection from a thread
    # other than the one that created it.
    conn = sqlite3.connect(RESULT_CACHE_DB, check_same_thread=False)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS result_cache (
                query_hash TEXT PRIMARY KEY,
                query_text TEXT,
                result_json TEXT,
                created_at REAL
            )
        """)
        conn.commit()
    except Exception:
        # Never leak a half-initialised connection to the garbage collector.
        conn.close()
        raise
    return conn
def _query_hash(question: str) -> str:
    """Derive the result-cache key for a question.

    The question is lower-cased and every run of whitespace is collapsed to
    a single space (leading and trailing whitespace is dropped), so queries
    that differ only in case or spacing share one key.

    Returns the SHA-256 hex digest of the normalised text.
    """
    words = question.lower().split()
    canonical = " ".join(words)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()
def get_cached_result(question: str) -> dict | None:
    """Look up a previously cached result for ``question``.

    Args:
        question: Raw query text; it is normalised and hashed to form the
            cache key.

    Returns:
        The JSON-decoded cached result, or ``None`` when there is no entry,
        when the entry is older than ``RESULT_TTL_SECS`` (the stale row is
        then deleted), or when any error occurs while reading (the error is
        reported on stdout).
    """
    key = _query_hash(question)
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                row = conn.execute(
                    "SELECT result_json, created_at FROM result_cache WHERE query_hash = ?",
                    (key,),
                ).fetchone()
            finally:
                # Close even if the query raises, so a failed read does not
                # leave an open connection behind.
                conn.close()

        if row is None:
            return None

        result_json, created_at = row
        if RESULT_TTL_SECS > 0 and (time.time() - created_at) > RESULT_TTL_SECS:
            # Expired. The delete runs after the lock has been released
            # because delete_cached_result acquires the same lock itself.
            delete_cached_result(question)
            return None
        return json.loads(result_json)
    except Exception as e:
        print(f" [ResultCache] Read error: {e}", flush=True)
        return None
def set_cached_result(question: str, result: dict):
    """Store a result dict in the SQLite result cache.

    Results whose ``mode`` is ``'none'`` are not cached. The
    ``debate_rounds`` entry is dropped before storing; every other key is
    kept. An existing row for the same normalised query is replaced.

    Any failure, including a result that is not JSON-serialisable, is
    reported on stdout and otherwise ignored.

    Args:
        question: Raw query text; also stored verbatim in ``query_text``.
        result: The result mapping to cache.
    """
    if result.get("mode") == "none":
        return
    key = _query_hash(question)
    try:
        # Omit the large debate transcripts; keep everything else.
        serialisable = {
            k: v for k, v in result.items() if k not in ("debate_rounds",)
        }
        # Serialise before taking the lock or opening the database, so an
        # unserialisable result fails without holding either resource.
        payload = json.dumps(serialisable, ensure_ascii=False)
        with _db_lock:
            conn = _get_conn()
            try:
                conn.execute(
                    """INSERT OR REPLACE INTO result_cache
                    (query_hash, query_text, result_json, created_at)
                    VALUES (?, ?, ?, ?)""",
                    (key, question, payload, time.time()),
                )
                conn.commit()
            finally:
                # Close on both success and failure to avoid leaking the
                # connection when the insert or commit raises.
                conn.close()
    except Exception as e:
        print(f" [ResultCache] Write error: {e}", flush=True)
def delete_cached_result(question: str):
    """Remove the cached entry for ``question``, if one exists.

    Best-effort: a failure is reported on stdout and never raised to the
    caller.
    """
    key = _query_hash(question)
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                conn.execute(
                    "DELETE FROM result_cache WHERE query_hash = ?", (key,)
                )
                conn.commit()
            finally:
                # Close even when the delete or commit raises.
                conn.close()
    except Exception as e:
        # Report the failure like the read and write paths do, instead of
        # hiding it.
        print(f" [ResultCache] Delete error: {e}", flush=True)
def cache_stats() -> dict:
    """Return basic stats about the result cache.

    Returns:
        A dict with ``cached_results`` (number of rows in ``result_cache``)
        and ``oldest_entry`` (the smallest ``created_at`` value, or ``None``
        when the table is empty). If the database cannot be read, the
        fallback ``{"cached_results": 0, "oldest_entry": None}`` is returned.
    """
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                # One aggregate query yields both figures from the same
                # snapshot of the table.
                count, oldest = conn.execute(
                    "SELECT COUNT(*), MIN(created_at) FROM result_cache"
                ).fetchone()
            finally:
                # Close even when the query raises.
                conn.close()
        return {"cached_results": count, "oldest_entry": oldest}
    except Exception:
        # Stats are informational only; fall back to an empty-cache answer.
        return {"cached_results": 0, "oldest_entry": None}