gemeo-twin-stack / src /gemeo /cache.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""AMG-RAG-style continuous KG cache.
Following "Agentic Medical Knowledge Graphs Enhance Medical QA" (arXiv
2502.13010), every successful gemeo_lookup result is written back into
Aura as a `:LookupCache` node with a hash key over (case_id, query, mode). The
next call with the same query hits the cache → faster, more consistent.
The cache is also linked to the entities that appeared in the result, so
graph queries can surface "this lookup was useful for these patients before".
Schema:
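(:LookupCache {hash, query, case_id, mode, ts, payload, hits, last_hit_ts})
-[:RETRIEVED]->(:Disease|:Phenotype|:Gene|:Drug)
(only HP:/ORPHA: codes are linked today, i.e. :Phenotype and :Disease)
(:PatientSpace)-[:CACHED_LOOKUP]->(:LookupCache)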
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
from typing import Optional
logger = logging.getLogger("gemeo.cache")
def _hash(case_id: str, query: str, mode: str) -> str:
    """Derive the 16-hex-character cache key for one lookup.

    The query is stripped and lower-cased before hashing, so queries that
    differ only in surrounding whitespace or letter case share a key.
    """
    normalized_query = query.strip().lower()
    key_material = f"{case_id}|{normalized_query}|{mode}"
    digest = hashlib.sha256(key_material.encode())
    # Keep only the first 16 hex characters of the SHA-256 digest.
    return digest.hexdigest()[:16]
async def _q(cypher: str, params: Optional[dict] = None):
    """Run a Cypher statement through ``space_graph._safe_query``, best-effort.

    Because the import sits inside the ``try``, a missing or broken
    ``space_graph`` module is handled the same way as a failed query: the
    error is logged at debug level and an empty list is returned.

    Args:
        cypher: Cypher text to execute.
        params: Query parameters; ``None`` is sent as an empty dict.

    Returns:
        Whatever ``_safe_query`` returns (callers in this module index it as a
        sequence of mapping-like rows), or ``[]`` if the import or the query
        raised. Exceptions are never propagated to the caller.
    """
    try:
        from space_graph import _safe_query

        return await _safe_query(cypher, params or {}, timeout=10.0)
    except Exception as e:  # best-effort: any failure becomes an empty result
        logger.debug("cache cypher failed: %s", e)
        return []
async def get(case_id: str, query: str, mode: str = "local",
              ttl_hours: float = 24.0) -> Optional[dict]:
    """Return the cached lookup result if one exists and is still fresh.

    Args:
        case_id: Case identifier that is part of the cache key.
        query: Lookup query; normalised by ``_hash`` before keying.
        mode: Lookup mode that is part of the cache key.
        ttl_hours: Maximum entry age in hours; entries whose ``ts`` is not
            newer than ``now - ttl_hours`` are treated as misses.

    Returns:
        The decoded payload dict with two extra keys, ``_cache_hit`` (always
        ``True``) and ``_cache_hits_total`` (stored hit count plus this hit),
        or ``None`` on a miss, a stale entry, or an unreadable payload.
    """
    h = _hash(case_id, query, mode)
    now = time.time()
    rows = await _q("""
        MATCH (c:LookupCache {hash: $hash})
        WHERE c.ts > $cutoff
        RETURN c.payload AS payload, c.hits AS hits
    """, {"hash": h, "cutoff": now - ttl_hours * 3600})
    if not rows:
        return None

    # Decode first: an entry that cannot be read back is a miss and must not
    # be counted as a hit.
    try:
        row = rows[0]
        payload = json.loads(row["payload"])
        payload["_cache_hit"] = True
        payload["_cache_hits_total"] = (row.get("hits") or 0) + 1
    except Exception as e:  # best-effort: a bad entry is just a cache miss
        logger.debug("cache parse failed: %s", e)
        return None

    # Record the hit. The write is awaited, but `_q` swallows its own errors,
    # so a failed bump never affects the value returned to the caller.
    await _q(
        "MATCH (c:LookupCache {hash: $hash}) "
        "SET c.hits = coalesce(c.hits, 0) + 1, c.last_hit_ts = $ts",
        {"hash": h, "ts": now},
    )
    return payload
async def put(case_id: str, query: str, mode: str, result: dict, *,
              max_payload_chars: int = 50_000) -> bool:
    """Store a lookup result and link it to retrieved entities and the case.

    Three writes are attempted, all through ``_q``:

    1. MERGE the ``:LookupCache`` node and set its properties (the existing
       ``hits`` value is preserved, or initialised to 0).
    2. For the first 20 entries of ``result["triples"]``, extract ``HP:``
       and ``ORPHA:`` codes from the third element and MERGE a ``RETRIEVED``
       relationship to the matching ``:Phenotype`` / ``:Disease`` node.
    3. If ``case_id`` is non-empty, MERGE
       ``(:PatientSpace)-[:CACHED_LOOKUP]->(:LookupCache)`` when the
       PatientSpace exists.

    Args:
        case_id: Case identifier; part of the cache key and the PatientSpace
            ``space_id`` to link to.
        query: Lookup query text.
        mode: Lookup mode; part of the cache key.
        result: Lookup result; serialised with ``json.dumps(default=str)``.
        max_payload_chars: Largest serialised payload that will be cached.
            Larger payloads are skipped entirely rather than truncated,
            because a truncated JSON document can never be decoded by ``get``.

    Returns:
        ``True`` if all statements were issued, ``False`` if the payload was
        too large or an exception occurred. NOTE(review): ``_q`` swallows
        database errors, so ``True`` does not prove the data was persisted.
    """
    import re  # function-scope: `re` is not imported at module level

    # (compiled pattern, node label, node property) for linkable entity codes.
    entity_patterns = (
        (re.compile(r"HP:\d{7}"), "Phenotype", "hpoId"),
        (re.compile(r"ORPHA:\d+"), "Disease", "orphaCode"),
    )

    h = _hash(case_id, query, mode)
    try:
        payload = json.dumps(result, default=str)
        if len(payload) > max_payload_chars:
            logger.debug(
                "cache put skipped: payload is %d chars (limit %d)",
                len(payload), max_payload_chars,
            )
            return False

        await _q("""
            MERGE (c:LookupCache {hash: $hash})
            SET c.query = $query, c.case_id = $case_id, c.mode = $mode,
                c.ts = $ts, c.payload = $payload, c.hits = coalesce(c.hits, 0)
        """, {
            "hash": h, "query": query, "case_id": case_id, "mode": mode,
            "ts": time.time(), "payload": payload,
        })

        # Collect each distinct entity once (insertion-ordered) so repeated
        # codes across triples do not cost extra database round trips.
        links = {}
        for trip in (result.get("triples") or [])[:20]:
            if not isinstance(trip, (list, tuple)) or len(trip) < 3:
                continue
            target = str(trip[2])
            for pattern, label, prop in entity_patterns:
                m = pattern.search(target)
                if m is None:
                    continue
                # Disease nodes are matched on the bare number, without the
                # "ORPHA:" prefix; HP codes contain no such prefix and are
                # left unchanged by the replace.
                code = m.group().replace("ORPHA:", "")
                links.setdefault((label, prop, code))

        # Link to retrieved entities (best-effort; failure is non-fatal).
        # `label` and `prop` come only from the fixed table above, so
        # interpolating them into the Cypher text is safe.
        for label, prop, code in links:
            await _q(f"""
                MATCH (c:LookupCache {{hash: $h}})
                MATCH (e:{label} {{{prop}: $code}})
                MERGE (c)-[:RETRIEVED]->(e)
            """, {"h": h, "code": code})

        # Link to the calling PatientSpace, if it exists.
        if case_id:
            await _q("""
                MATCH (c:LookupCache {hash: $h})
                OPTIONAL MATCH (s:PatientSpace {space_id: $case_id})
                FOREACH (_ IN CASE WHEN s IS NULL THEN [] ELSE [1] END |
                    MERGE (s)-[:CACHED_LOOKUP]->(c)
                )
            """, {"h": h, "case_id": case_id})
        return True
    except Exception as e:  # best-effort: caching must not break the lookup
        logger.debug("cache put failed: %s", e)
        return False
async def stats() -> dict:
    """Cache statistics — useful for /api/gemeo/health.

    Returns:
        A dict with the keys ``n_entries``, ``total_hits``, ``avg_hits`` and
        ``max_hits``. All four keys are always present; they are zero when
        the query returns no rows (which is also what ``_q`` yields on
        failure) or when an aggregate comes back as null.
    """
    rows = await _q("""
        MATCH (c:LookupCache)
        RETURN count(c) AS n, sum(c.hits) AS total_hits,
               avg(c.hits) AS avg_hits, max(c.hits) AS max_hits
    """)
    # An empty mapping makes the no-rows case fall through to the same
    # defaults, so the result shape does not depend on whether rows came back.
    r = rows[0] if rows else {}
    return {
        "n_entries": int(r.get("n") or 0),
        "total_hits": int(r.get("total_hits") or 0),
        "avg_hits": float(r.get("avg_hits") or 0),
        "max_hits": int(r.get("max_hits") or 0),
    }