File size: 4,781 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""AMG-RAG-style continuous KG cache.

Following "Agentic Medical Knowledge Graphs Enhance Medical QA" (arXiv
2502.13010), every successful gemeo_lookup result is written back into
Aura as a `:LookupCache` node keyed by a hash over (case_id, query, mode).
The next call with the same case, query and mode hits the cache → faster,
more consistent.

The cache is also linked to the entities that appeared in the result, so
graph queries can surface "this lookup was useful for these patients before".

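Schema (as written by `put` / `get` in this module):
    (:LookupCache {hash, query, case_id, mode, ts, payload, hits, last_hit_ts})
        -[:RETRIEVED]->(:Disease|:Phenotype)
    (:PatientSpace)-[:CACHED_LOOKUP]->(:LookupCache)

`put` currently links only :Phenotype (matched by hpoId) and :Disease
(matched by orphaCode); :Gene and :Drug nodes are not linked yet.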
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
from typing import Optional

logger = logging.getLogger("gemeo.cache")


def _hash(case_id: str, query: str, mode: str) -> str:
    """Return the 16-hex-character cache key for (case_id, query, mode).

    The query is stripped and lower-cased first, so lookups that differ only
    in surrounding whitespace or letter case share one cache entry.
    """
    normalized_query = query.strip().lower()
    key_material = f"{case_id}|{normalized_query}|{mode}"
    digest = hashlib.sha256(key_material.encode()).hexdigest()
    # Only the first 16 hex characters (64 bits) of the digest form the key.
    return digest[:16]


async def _q(cypher: str, params: Optional[dict] = None):
    """Run a Cypher statement through ``space_graph._safe_query``, best-effort.

    Args:
        cypher: Cypher text to execute.
        params: Query parameters; ``None`` (or an empty dict) is sent as ``{}``.

    Returns:
        Whatever ``_safe_query`` returns, or ``[]`` when the import or the
        query raises, so callers can treat any failure as "no rows".
    """
    try:
        # Imported at call time rather than at module load.
        # NOTE(review): presumably to avoid an import cycle with space_graph — confirm.
        from space_graph import _safe_query
        return await _safe_query(cypher, params or {}, timeout=10.0)
    except Exception as e:
        # The cache is optional: a graph failure must never reach the caller.
        logger.debug("cache cypher failed: %s", e)
        return []


async def get(case_id: str, query: str, mode: str = "local",
              ttl_hours: float = 24.0) -> Optional[dict]:
    """Return the cached lookup result if one exists and is still fresh.

    Args:
        case_id: Case the lookup belongs to; part of the cache key.
        query: Lookup query; normalised by ``_hash`` before keying.
        mode: Lookup mode; part of the cache key.
        ttl_hours: Maximum age of an entry, in hours, for it to count as fresh.

    Returns:
        The stored payload dict with ``_cache_hit`` set to True and
        ``_cache_hits_total`` set to the hit count including this call, or
        None when there is no fresh entry or the stored payload is unreadable.
    """
    h = _hash(case_id, query, mode)
    rows = await _q("""
    MATCH (c:LookupCache {hash: $hash})
    WHERE c.ts > $cutoff
    RETURN c.payload AS payload, c.hits AS hits
    """, {"hash": h, "cutoff": time.time() - ttl_hours * 3600})
    if not rows:
        return None

    # Decode and validate BEFORE touching the hit counter, so an entry whose
    # payload cannot be read is reported as a miss and is not counted as a hit.
    try:
        row = rows[0]
        payload = json.loads(row["payload"])
        previous_hits = row.get("hits") or 0
    except Exception as e:
        # Best-effort boundary: a broken cache entry must never fail the lookup.
        logger.debug(f"cache parse failed: {e}")
        return None
    if not isinstance(payload, dict):
        logger.debug(f"cache parse failed: payload is {type(payload).__name__}, not dict")
        return None

    # Bump the hit counter. The call is awaited (not fire-and-forget), but
    # `_q` swallows query errors, so a failed bump cannot break the read.
    await _q("MATCH (c:LookupCache {hash: $hash}) SET c.hits = coalesce(c.hits, 0) + 1, c.last_hit_ts = $ts",
             {"hash": h, "ts": time.time()})
    payload["_cache_hit"] = True
    payload["_cache_hits_total"] = previous_hits + 1
    return payload


async def put(case_id: str, query: str, mode: str, result: dict, *,
              max_payload_chars: int = 50_000) -> bool:
    """Store a lookup result and link it to its entities and the calling case.

    The result is serialised to JSON and written to the ``:LookupCache`` node
    keyed by ``_hash(case_id, query, mode)``. The node is then linked to the
    phenotype / disease nodes whose codes appear in the object (third element)
    of the first 20 entries of ``result["triples"]``, and to the
    ``:PatientSpace`` whose ``space_id`` equals ``case_id``, if one exists.

    Args:
        case_id: Case the lookup belongs to; part of the cache key.
        query: Lookup query; stored as given, normalised only for the key.
        mode: Lookup mode; part of the cache key.
        result: Lookup result to cache. Values that are not JSON-serialisable
            are stringified (``default=str``).
        max_payload_chars: Largest serialised payload that will be stored.

    Returns:
        True once all writes have been attempted; False when the result is
        larger than ``max_payload_chars`` or could not be processed. ``_q``
        swallows query errors, so True does not prove the node was persisted.
    """
    import re  # function-scope: `re` is not among the module-level imports

    # Each pattern's single capture group is exactly the value stored on the
    # node property: the full "HP:nnnnnnn" id for phenotypes, and the bare
    # number (without the "ORPHA:" prefix) for diseases.
    entity_patterns = [
        (re.compile(r"(HP:\d{7})"), "Phenotype", "hpoId"),
        (re.compile(r"ORPHA:(\d+)"), "Disease", "orphaCode"),
    ]
    h = _hash(case_id, query, mode)
    try:
        payload = json.dumps(result, default=str)
        if len(payload) > max_payload_chars:
            # Slicing the JSON text would store a document that json.loads()
            # can never parse, i.e. a permanently unreadable cache entry.
            # Skip caching instead.
            logger.debug("cache put skipped: payload is %d chars (limit %d)",
                         len(payload), max_payload_chars)
            return False
        await _q("""
        MERGE (c:LookupCache {hash: $hash})
        SET c.query = $query, c.case_id = $case_id, c.mode = $mode,
            c.ts = $ts, c.payload = $payload, c.hits = coalesce(c.hits, 0)
        """, {
            "hash": h, "query": query, "case_id": case_id, "mode": mode,
            "ts": time.time(), "payload": payload,
        })
        # Link to retrieved entities (best-effort; failure is non-fatal)
        for trip in (result.get("triples") or [])[:20]:
            if not isinstance(trip, (list, tuple)) or len(trip) < 3:
                continue
            target = str(trip[2])
            for pattern, label, prop in entity_patterns:
                m = pattern.search(target)
                if m is None:
                    continue
                # `label` and `prop` come only from the hard-coded table above,
                # never from input, so interpolating them into Cypher is safe;
                # the code itself is passed as a query parameter.
                await _q(f"""
                    MATCH (c:LookupCache {{hash: $h}})
                    MATCH (e:{label} {{{prop}: $code}})
                    MERGE (c)-[:RETRIEVED]->(e)
                    """, {"h": h, "code": m.group(1)})
        # Link to calling PatientSpace
        if case_id:
            # OPTIONAL MATCH + FOREACH creates the relationship only when the
            # PatientSpace exists.
            await _q("""
            MATCH (c:LookupCache {hash: $h})
            OPTIONAL MATCH (s:PatientSpace {space_id: $case_id})
            FOREACH (_ IN CASE WHEN s IS NULL THEN [] ELSE [1] END |
                MERGE (s)-[:CACHED_LOOKUP]->(c)
            )
            """, {"h": h, "case_id": case_id})
        return True
    except Exception as e:
        # Best-effort boundary: caching must never fail the lookup that fed it.
        logger.debug("cache put failed: %s", e)
        return False


async def stats() -> dict:
    """Cache statistics — useful for /api/gemeo/health.

    Returns:
        A dict with ``n_entries``, ``total_hits``, ``avg_hits`` and
        ``max_hits``. All four keys are always present; they are zero when the
        query returns no rows (including when ``_q`` swallowed a failure) or
        when an aggregate comes back as null.
    """
    rows = await _q("""
    MATCH (c:LookupCache)
    RETURN count(c) AS n, sum(c.hits) AS total_hits,
           avg(c.hits) AS avg_hits, max(c.hits) AS max_hits
    """)
    # Fall back to an empty row so the result shape does not depend on whether
    # the query produced anything.
    r = rows[0] if rows else {}
    return {
        "n_entries": int(r.get("n") or 0),
        "total_hits": int(r.get("total_hits") or 0),
        "avg_hits": float(r.get("avg_hits") or 0),
        "max_hits": int(r.get("max_hits") or 0),
    }