File size: 4,701 Bytes
b59fc2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Disease-level result cache — avoids repeating RAG+LLM queries for the same disease.

Stores (symptoms_text, standard_text, sources_json) per disease in SQLite.
TTL: 7 days (medical knowledge stable; rebuild only after data update).

Impact:
  - Cold (first request for a disease):  3 LLM calls + FAISS search
  - Warm (subsequent requests):          0 LLM calls, 0 FAISS search → ~5 s → <0.1 s
"""
import json
import logging
import sqlite3
import time
from contextlib import closing
from pathlib import Path
from typing import Optional, Tuple, List, Dict

logger = logging.getLogger(__name__)

TTL_SECONDS = 7 * 24 * 60 * 60   # 7 days


class DiseaseCache:
    """SQLite-backed cache of per-disease results (symptoms, standard, sources).

    Every operation opens its own short-lived connection and closes it before
    returning, so no connection or file handle outlives the call.

    Args:
        db_path: Path of the SQLite file. Defaults to ``disease_cache.db`` in
            the parent of the directory that contains this module.
        ttl_seconds: Maximum age of an entry, in seconds, before ``get``
            treats it as expired. Defaults to ``TTL_SECONDS``.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        *,
        ttl_seconds: Optional[float] = None,
    ):
        if db_path is None:
            db_path = str(Path(__file__).parent.parent / "disease_cache.db")
        self.db_path = db_path
        self.ttl_seconds = TTL_SECONDS if ttl_seconds is None else ttl_seconds
        self._init_db()
        logger.info("[DiseaseCache] SQLite cache at: %s", self.db_path)

    def _conn(self) -> sqlite3.Connection:
        """Open a new connection to the cache file.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self) -> None:
        """Create the ``disease_cache`` table if it does not exist yet."""
        # A sqlite3 connection used as a context manager only commits (or
        # rolls back on error); it never closes. ``closing`` adds the close.
        with closing(self._conn()) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS disease_cache (
                    disease    TEXT PRIMARY KEY,
                    symptoms   TEXT NOT NULL,
                    standard   TEXT NOT NULL,
                    sources    TEXT NOT NULL DEFAULT '[]',
                    cached_at  REAL NOT NULL
                )
            """)

    # ── read ──────────────────────────────────────────────────────────────────
    def get(self, disease: str) -> Optional[Dict]:
        """Return cached ``{symptoms, standard, sources}`` for *disease*.

        Returns ``None`` when there is no entry, when the entry is older than
        the TTL, or when its stored ``sources`` JSON cannot be decoded. In the
        last two cases the entry is also deleted.
        """
        with closing(self._conn()) as conn:
            row = conn.execute(
                "SELECT symptoms, standard, sources, cached_at FROM disease_cache WHERE disease = ?",
                (disease,)
            ).fetchone()

        if row is None:
            return None

        age = time.time() - row["cached_at"]
        if age > self.ttl_seconds:
            self.invalidate(disease)
            logger.info(
                "[DiseaseCache] Cache expired for '%s' (%.1fh old)",
                disease, age / 3600,
            )
            return None

        try:
            sources = json.loads(row["sources"])
        except ValueError:
            # json.JSONDecodeError is a ValueError. A corrupt row would
            # otherwise fail every lookup until the TTL elapsed, so drop it
            # and report a miss; the caller will recompute and re-cache.
            self.invalidate(disease)
            logger.warning(
                "[DiseaseCache] Dropping entry with unreadable sources for '%s'",
                disease,
            )
            return None

        logger.info(
            "[DiseaseCache] Cache HIT for '%s' (%.1fh old)", disease, age / 3600
        )
        return {
            "symptoms": row["symptoms"],
            "standard": row["standard"],
            "sources":  sources,
        }

    # ── write ─────────────────────────────────────────────────────────────────
    def set(self, disease: str, symptoms: str, standard: str, sources: List[Dict]):
        """Insert or overwrite the entry for *disease*, stamped with the current time.

        ``sources`` is stored as a JSON string (non-ASCII characters kept as-is).
        """
        now = time.time()
        sources_json = json.dumps(sources, ensure_ascii=False)
        with closing(self._conn()) as conn, conn:
            conn.execute("""
                INSERT INTO disease_cache (disease, symptoms, standard, sources, cached_at)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(disease) DO UPDATE
                    SET symptoms=excluded.symptoms,
                        standard=excluded.standard,
                        sources=excluded.sources,
                        cached_at=excluded.cached_at
            """, (disease, symptoms, standard, sources_json, now))
        logger.info("[DiseaseCache] Cached '%s'", disease)

    # ── management ────────────────────────────────────────────────────────────
    def invalidate(self, disease: str):
        """Delete the entry for *disease*; a no-op when there is none."""
        with closing(self._conn()) as conn, conn:
            conn.execute("DELETE FROM disease_cache WHERE disease = ?", (disease,))

    def invalidate_all(self):
        """Delete every cached entry."""
        with closing(self._conn()) as conn, conn:
            conn.execute("DELETE FROM disease_cache")
        logger.info("[DiseaseCache] All entries invalidated")

    def stats(self) -> Dict:
        """Return ``{"total", "fresh", "expired"}`` entry counts.

        An entry counts as fresh when ``cached_at >= now - ttl``, matching the
        expiry test in ``get`` (expired only when strictly older than the TTL).
        """
        cutoff = time.time() - self.ttl_seconds
        with closing(self._conn()) as conn:
            # One statement so both counts come from the same snapshot.
            # SUM() over zero rows is NULL, hence the COALESCE.
            total, fresh = conn.execute(
                "SELECT COUNT(*), COALESCE(SUM(cached_at >= ?), 0) FROM disease_cache",
                (cutoff,)
            ).fetchone()
        return {"total": total, "fresh": fresh, "expired": total - fresh}