Spaces:
Running
Running
| """ | |
| EpiRAG - cache.py | |
| ----------------- | |
| Two-layer caching system: | |
| 1. Embedding cache - writes/loads per-paper embedding vectors to disk | |
| so ingest.py does not re-embed on every restart. | |
| 2. Result cache - SQLite-backed, keyed by normalised query hash, | |
| caches the full JSON result for identical queries. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import json | |
| import time | |
| import hashlib | |
| import sqlite3 | |
| import threading | |
| import numpy as np | |
| # --- Config | |
# Directory that holds one .npy file per cached chunk embedding.
EMBED_CACHE_DIR = "./embeddings"
# SQLite database file backing the result cache.
RESULT_CACHE_DB = "./cache.db"
# Lifetime of a cached result in seconds (6 hours); set to 0 to disable TTL.
RESULT_TTL_SECS = 6 * 60 * 60
# Lock held around result-cache database access in this module.
_db_lock = threading.Lock()
| # --- Embedding cache | |
def embed_cache_path(paper_name: str, chunk_index: int) -> str:
    """Build the on-disk ``.npy`` path for one chunk of one paper.

    Every character of ``paper_name`` outside ``[a-zA-Z0-9_-]`` is replaced
    with an underscore before it is used in the file name, and the file is
    placed under ``EMBED_CACHE_DIR``.
    """
    sanitised_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", paper_name)
    file_name = f"{sanitised_name}_chunk_{chunk_index}.npy"
    return os.path.join(EMBED_CACHE_DIR, file_name)
def load_embedding(paper_name: str, chunk_index: int) -> list[float] | None:
    """Fetch a chunk's embedding from the disk cache.

    Returns the stored vector as a plain Python list, or None when no cache
    file exists or the file cannot be loaded/converted.
    """
    cache_file = embed_cache_path(paper_name, chunk_index)
    if not os.path.exists(cache_file):
        return None
    try:
        cached = np.load(cache_file)
        # Conversion stays inside the try: any failure here is also a miss.
        return cached.tolist()
    except Exception:
        return None
def save_embedding(paper_name: str, chunk_index: int, vector: list[float]):
    """Persist an embedding vector to disk as a float32 ``.npy`` file.

    Saving is best-effort: any failure (including failure to create the
    cache directory) is reported on stdout and swallowed, so a cache problem
    never aborts the caller.

    Args:
        paper_name: Paper identifier; sanitised by ``embed_cache_path``.
        chunk_index: Index of the chunk within the paper.
        vector: Embedding values; stored as ``np.float32``.
    """
    path = embed_cache_path(paper_name, chunk_index)
    try:
        # Directory creation lives inside the try so that an unwritable
        # location is handled the same way as a failed np.save.
        os.makedirs(EMBED_CACHE_DIR, exist_ok=True)
        np.save(path, np.array(vector, dtype=np.float32))
    except Exception as e:
        print(f" [EmbedCache] Failed to save {path}: {e}", flush=True)
| # --- Result cache (SQLite) | |
def _get_conn() -> sqlite3.Connection:
    """Open a connection to the result-cache database, creating the table if needed.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for
        closing it.

    Raises:
        Exception: whatever ``sqlite3`` raised while connecting or creating
            the table. The half-initialised connection is closed first so it
            does not leak.
    """
    conn = sqlite3.connect(RESULT_CACHE_DB, check_same_thread=False)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS result_cache (
                query_hash TEXT PRIMARY KEY,
                query_text TEXT,
                result_json TEXT,
                created_at REAL
            )
        """)
        conn.commit()
    except Exception:
        # The caller never sees this connection, so release it here.
        conn.close()
        raise
    return conn
def _query_hash(question: str) -> str:
    """Return the SHA-256 hex digest of the normalised query.

    Normalisation lower-cases the text and collapses every run of
    whitespace to a single space (leading/trailing whitespace is dropped).
    """
    tokens = question.lower().split()
    canonical = " ".join(tokens)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()
def get_cached_result(question: str) -> dict | None:
    """Return a previously cached result dict, or None if not found / expired.

    The lookup key is the hash of the normalised question. When
    ``RESULT_TTL_SECS`` is positive and the row is older than that, the row
    is deleted and None is returned. Any error (database or JSON decoding)
    is printed and reported as a cache miss.

    Args:
        question: The user query text.

    Returns:
        The decoded result dict, or None on miss, expiry, or error.
    """
    key = _query_hash(question)
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                row = conn.execute(
                    "SELECT result_json, created_at FROM result_cache WHERE query_hash = ?",
                    (key,)
                ).fetchone()
                if row is None:
                    return None
                result_json, created_at = row
                if RESULT_TTL_SECS > 0 and (time.time() - created_at) > RESULT_TTL_SECS:
                    # Evict while still holding the lock, so a fresh entry
                    # written by another thread cannot be deleted by mistake.
                    conn.execute(
                        "DELETE FROM result_cache WHERE query_hash = ?", (key,)
                    )
                    conn.commit()
                    return None
            finally:
                # Always release the connection, even if a query raised.
                conn.close()
        # Decode outside the lock; it needs no database access.
        return json.loads(result_json)
    except Exception as e:
        print(f" [ResultCache] Read error: {e}", flush=True)
        return None
def set_cached_result(question: str, result: dict):
    """Store a result dict in the SQLite cache.

    Results whose ``"mode"`` is ``"none"`` are not cached. The
    ``"debate_rounds"`` key is dropped before storing; every other key is
    serialised as JSON. Any error (serialisation or database) is printed and
    swallowed, so caching never breaks the caller.

    Args:
        question: The user query text; its normalised hash is the cache key.
        result: The result dict to cache.
    """
    if result.get("mode") == "none":
        return
    key = _query_hash(question)
    try:
        # Omit the large debate transcript; keep everything else.
        serialisable = {k: v for k, v in result.items()
                        if k not in ("debate_rounds",)}
        # Serialise before taking the lock so it is held only for the write.
        payload = json.dumps(serialisable, ensure_ascii=False)
        with _db_lock:
            conn = _get_conn()
            try:
                conn.execute(
                    """INSERT OR REPLACE INTO result_cache
                    (query_hash, query_text, result_json, created_at)
                    VALUES (?, ?, ?, ?)""",
                    (key, question, payload, time.time())
                )
                conn.commit()
            finally:
                # Always release the connection, even if the write raised.
                conn.close()
    except Exception as e:
        print(f" [ResultCache] Write error: {e}", flush=True)
def delete_cached_result(question: str):
    """Remove the cached result for ``question``, if any.

    Deletion is best-effort: errors are printed (matching the read/write
    helpers) and never raised to the caller.

    Args:
        question: The user query text; its normalised hash is the cache key.
    """
    key = _query_hash(question)
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                conn.execute("DELETE FROM result_cache WHERE query_hash = ?", (key,))
                conn.commit()
            finally:
                # Always release the connection, even if the delete raised.
                conn.close()
    except Exception as e:
        # Report instead of silently hiding the failure; still non-fatal.
        print(f" [ResultCache] Delete error: {e}", flush=True)
def cache_stats() -> dict:
    """Return basic stats about the result cache.

    Returns:
        A dict with ``"cached_results"`` (row count) and ``"oldest_entry"``
        (smallest ``created_at`` value, or None when the table is empty).
        On any error the failure is printed and
        ``{"cached_results": 0, "oldest_entry": None}`` is returned.
    """
    try:
        with _db_lock:
            conn = _get_conn()
            try:
                count = conn.execute("SELECT COUNT(*) FROM result_cache").fetchone()[0]
                oldest = conn.execute(
                    "SELECT MIN(created_at) FROM result_cache"
                ).fetchone()[0]
            finally:
                # Always release the connection, even if a query raised.
                conn.close()
        return {"cached_results": count, "oldest_entry": oldest}
    except Exception as e:
        # Report the failure so a broken DB is not mistaken for an empty cache.
        print(f" [ResultCache] Stats error: {e}", flush=True)
        return {"cached_results": 0, "oldest_entry": None}