| """AMG-RAG-style continuous KG cache. |
| |
Following "Agentic Medical Knowledge Graphs Enhance Medical QA" (arXiv
2502.13010), every successful gemeo_lookup result is written back into
Aura as a `:LookupCache` node keyed by a hash over (case_id, query, mode),
with the query whitespace-stripped and lower-cased before hashing. The next
call with the same case, query and mode within the TTL hits the cache →
faster, more consistent.
| |
| The cache is also linked to the entities that appeared in the result, so |
| graph queries can surface "this lookup was useful for these patients before". |
| |
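Schema:
    (:LookupCache {hash, query, case_id, mode, ts, payload, hits, last_hit_ts})
        -[:RETRIEVED]->(:Disease|:Phenotype)
    (:PatientSpace)-[:CACHED_LOOKUP]->(:LookupCache)

Only HPO (:Phenotype) and ORPHA (:Disease) codes found in a result's triples
are currently linked via :RETRIEVED (:Gene and :Drug targets are not linked
yet); last_hit_ts is set whenever an entry is served as a cache hit.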
| """ |
| from __future__ import annotations |
| import hashlib |
| import json |
| import logging |
| import time |
| from typing import Optional |
|
|
# Module logger. Every failure path in this cache logs at DEBUG, so a broken
# or unreachable graph degrades to cache misses instead of surfacing errors.
logger = logging.getLogger(name="gemeo.cache")
|
|
|
|
def _hash(case_id: str, query: str, mode: str) -> str:
    """Return a 16-hex-character cache key for (case_id, query, mode).

    The query is whitespace-stripped and lower-cased first, so spellings that
    differ only in surrounding whitespace or letter case share one key. The
    key is the first 16 hex digits of the SHA-256 of "case_id|query|mode".
    """
    normalized_query = query.strip().lower()
    key_material = "|".join((f"{case_id}", normalized_query, f"{mode}"))
    digest = hashlib.sha256()
    digest.update(key_material.encode())
    return digest.hexdigest()[:16]
|
|
|
|
async def _q(cypher: str, params: Optional[dict] = None):
    """Run *cypher* through ``space_graph._safe_query`` and return its rows.

    Best-effort: any exception (including ``space_graph`` not being
    importable) is logged at DEBUG and turned into ``[]``, so callers cannot
    distinguish a failed query from one that returned no rows.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is sent as an empty dict.

    Returns:
        Whatever ``_safe_query`` returns, or ``[]`` on failure.
    """
    try:
        # Imported lazily at call time; presumably to avoid a hard/cyclic
        # module dependency on the graph layer — TODO confirm.
        from space_graph import _safe_query
        # timeout=10.0 is passed straight through (presumably seconds).
        return await _safe_query(cypher, params or {}, timeout=10.0)
    except Exception as e:
        logger.debug("cache cypher failed: %s", e)
        return []
|
|
|
|
async def get(case_id: str, query: str, mode: str = "local",
              ttl_hours: float = 24.0) -> Optional[dict]:
    """Return the cached lookup result for (case_id, query, mode), or None.

    An entry is served only if its ``ts`` is newer than ``ttl_hours`` ago and
    it has a stored ``payload``. On a hit, the node's ``hits`` counter is
    incremented and ``last_hit_ts`` is set, and the returned dict carries
    ``_cache_hit=True`` plus ``_cache_hits_total`` (the counter including
    this hit).

    Any failure (query error, unparseable payload) is logged at DEBUG and
    treated as a miss; a payload that cannot be parsed is not counted as a hit.
    """
    h = _hash(case_id, query, mode)
    rows = await _q("""
        MATCH (c:LookupCache {hash: $hash})
        WHERE c.ts > $cutoff AND c.payload IS NOT NULL
        RETURN c.payload AS payload, c.hits AS hits
    """, {"hash": h, "cutoff": time.time() - ttl_hours * 3600})
    if not rows:
        return None
    try:
        # Parse before touching the counter so a corrupt entry is not
        # recorded as a hit.
        payload = json.loads(rows[0]["payload"])
        payload["_cache_hit"] = True
        payload["_cache_hits_total"] = (rows[0].get("hits") or 0) + 1
    except Exception as e:
        logger.debug("cache parse failed: %s", e)
        return None
    await _q(
        "MATCH (c:LookupCache {hash: $hash}) "
        "SET c.hits = coalesce(c.hits, 0) + 1, c.last_hit_ts = $ts",
        {"hash": h, "ts": time.time()},
    )
    return payload
|
|
|
|
# Upper bound (in characters) on the serialized JSON payload stored on a
# :LookupCache node. Larger results are not stored rather than truncated.
_MAX_PAYLOAD_CHARS = 50_000


async def put(case_id: str, query: str, mode: str, result: dict) -> bool:
    """Store a lookup result and link it to retrieved entities and the case.

    Steps:
      1. Upsert a ``:LookupCache`` node keyed by ``_hash(case_id, query, mode)``
         and set its query/case_id/mode/ts/payload, preserving ``hits``.
      2. For the first 20 entries of ``result["triples"]`` that are
         lists/tuples of length >= 3, scan the third element for an
         ``HP:nnnnnnn`` or ``ORPHA:n`` code and MERGE
         ``-[:RETRIEVED]->`` to the matching ``:Phenotype`` (``hpoId``) or
         ``:Disease`` (``orphaCode``, bare number).
      3. If ``case_id`` is truthy and a ``:PatientSpace`` with that
         ``space_id`` exists, MERGE ``(space)-[:CACHED_LOOKUP]->(cache)``.

    The payload is ``result`` as JSON (non-JSON values rendered via ``str``).
    If it exceeds ``_MAX_PAYLOAD_CHARS`` it is not stored (the property is
    cleared): a truncated JSON string could never be parsed back by ``get``.
    The node and its links are still written in that case.

    Returns:
        True once all writes have been issued; False if preparing them raised
        (e.g. ``result`` is not serializable or not a dict). ``_q`` swallows
        Cypher errors, so True does not prove the writes succeeded.
    """
    import re

    h = _hash(case_id, query, mode)
    # (pattern, node label, code property); group 1 is the value stored on
    # the entity node: the full "HP:nnnnnnn" id, or the bare ORPHA number.
    code_patterns = (
        (re.compile(r"(HP:\d{7})"), "Phenotype", "hpoId"),
        (re.compile(r"ORPHA:(\d+)"), "Disease", "orphaCode"),
    )
    try:
        payload: Optional[str] = json.dumps(result, default=str)
        if len(payload) > _MAX_PAYLOAD_CHARS:
            logger.debug("cache payload too large (%d chars), not stored",
                         len(payload))
            payload = None
        await _q("""
            MERGE (c:LookupCache {hash: $hash})
            SET c.query = $query, c.case_id = $case_id, c.mode = $mode,
                c.ts = $ts, c.payload = $payload, c.hits = coalesce(c.hits, 0)
        """, {
            "hash": h, "query": query, "case_id": case_id, "mode": mode,
            "ts": time.time(), "payload": payload,
        })

        for trip in (result.get("triples") or [])[:20]:
            if not isinstance(trip, (list, tuple)) or len(trip) < 3:
                continue
            target = str(trip[2])
            for pattern, label, prop in code_patterns:
                m = pattern.search(target)
                if m is None:
                    continue
                # label/prop come from the fixed table above, never from input.
                await _q(f"""
                    MATCH (c:LookupCache {{hash: $h}})
                    MATCH (e:{label} {{{prop}: $code}})
                    MERGE (c)-[:RETRIEVED]->(e)
                """, {"h": h, "code": m.group(1)})

        if case_id:
            # Link only if the PatientSpace already exists; never create it.
            await _q("""
                MATCH (c:LookupCache {hash: $h})
                OPTIONAL MATCH (s:PatientSpace {space_id: $case_id})
                FOREACH (_ IN CASE WHEN s IS NULL THEN [] ELSE [1] END |
                    MERGE (s)-[:CACHED_LOOKUP]->(c)
                )
            """, {"h": h, "case_id": case_id})
        return True
    except Exception as e:
        logger.debug("cache put failed: %s", e)
        return False
|
|
|
|
async def stats() -> dict:
    """Return aggregate cache statistics — useful for /api/gemeo/health.

    Keys (always all present):
        n_entries:  number of :LookupCache nodes.
        total_hits: sum of their ``hits`` counters.
        avg_hits:   mean ``hits`` (float).
        max_hits:   largest ``hits``.

    Values are zero when there are no entries, or when the query yields no
    rows at all (e.g. ``_q`` swallowed an error).
    """
    rows = await _q("""
        MATCH (c:LookupCache)
        RETURN count(c) AS n, sum(c.hits) AS total_hits,
               avg(c.hits) AS avg_hits, max(c.hits) AS max_hits
    """)
    # Fall back to an empty row so the result shape never depends on
    # whether the query returned anything.
    r = rows[0] if rows else {}
    return {
        "n_entries": int(r.get("n") or 0),
        "total_hits": int(r.get("total_hits") or 0),
        "avg_hits": float(r.get("avg_hits") or 0),
        "max_hits": int(r.get("max_hits") or 0),
    }
|
|