File size: 8,783 Bytes
089d665 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 | """GraphRAG over the patient gêmeo.
Inspired by Microsoft GraphRAG (2024) + LightRAG (2024) + GFM-RAG (2025).
Adapted to the rare-disease patient setting.
Two retrieval modes:
- **local** (default) — retrieve from the patient subgraph (HPO + Disease
+ Gene + Drug nodes already linked to this twin). Fast, low-cost.
- **global** — also pull cohort exemplars and PubMed cases that share
matching nodes; slower, broader.
The result is a compact JSON object the LLM can consume directly:
{
"query": "...",
"triples": [["Patient", "HAS_PHENOTYPE", "HP:0001250 (Seizure)"], ...],
"communities": [{"label": "Phenotype cluster A", "members": [...], "summary": "..."}],
"cohort_exemplars": [{"space_id": "...", "diagnosis": "...", "shared_hpos": [...]}],
"literature": [{"pmid": "...", "title": "...", "year": ...}]
}
"""
from __future__ import annotations
import logging
from typing import Optional
# Module-level logger; records are emitted under the "gemeo.graphrag" name.
logger = logging.getLogger("gemeo.graphrag")
# Default cap used for the `limit` / `k` parameters of the retrieval helpers in this module.
DEFAULT_K = 12
async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query, returning ``[]`` instead of raising on any failure.

    Delegates to ``space_graph._safe_query`` with a 10-second timeout. The
    import happens lazily inside the ``try`` block, so a missing or broken
    ``space_graph`` module degrades to an empty result as well.

    Args:
        cypher: Cypher query text.
        params: Query parameters; ``None`` (or any falsy value) is sent as
            an empty dict.

    Returns:
        Whatever the delegate returns, or an empty list when the import or
        the query raised.
    """
    try:
        from space_graph import _safe_query as q
        return await q(cypher, params or {}, timeout=10.0)
    except Exception as e:
        # Deliberate best-effort: a failed lookup must not break the caller.
        # Lazy %-formatting avoids building the message when DEBUG is off.
        logger.debug("cypher failed: %s", e)
        return []
def _embed_query_cached(query: str):
    """Embed a free-text query through ``biolord_normalizer.normalize_one``.

    The import is attempted on every call. Any exception, whether raised by
    the import or by the normalizer itself, results in ``None`` so callers
    can fall back to lexical search.

    Returns:
        The value produced by ``normalize_one(query)``, or ``None`` on failure.
    """
    embedding = None
    try:
        import biolord_normalizer
        embedding = biolord_normalizer.normalize_one(query)
    except Exception:
        # Best-effort: an unavailable embedder is not an error here.
        embedding = None
    return embedding
async def _kg_lexical(query: str, limit: int = DEFAULT_K) -> list:
    """Case-insensitive substring search over Disease/Phenotype/Gene/Drug nodes.

    The whole ``query`` string is used as the needle against ``n.name`` and
    ``n.cid10DescriptionPt``; it is not tokenised here.

    Args:
        query: Free-text needle. Blank or ``None`` yields ``[]``.
        limit: Maximum number of rows requested from the graph.

    Returns:
        Rows with ``kind`` (first node label), ``name`` and ``code`` (first
        non-null of orphaCode, hpoId, symbol, rxcui); ``[]`` on failure.
    """
    # An empty needle satisfies CONTAINS for every string (including the
    # coalesced ''), so a blank query would return `limit` arbitrary nodes
    # that would then be reported as matches.
    if not (query or "").strip():
        return []
    cypher = """
MATCH (n)
WHERE (n:Disease OR n:Phenotype OR n:Gene OR n:Drug)
AND (toLower(n.name) CONTAINS toLower($q)
OR toLower(coalesce(n.cid10DescriptionPt, '')) CONTAINS toLower($q))
RETURN labels(n)[0] AS kind, n.name AS name,
coalesce(n.orphaCode, n.hpoId, n.symbol, n.rxcui) AS code
LIMIT $limit
"""
    return await _safe_query(cypher, {"q": query, "limit": int(limit)})
async def _patient_triples(patient_id: str, query_terms: list[str], limit: int = DEFAULT_K) -> list:
    """Build triples describing the patient's subgraph.

    One query collects the phenotypes reached through the patient's events,
    the diseases linked to those phenotypes, and the diseases targeted by
    the patient's hypotheses.

    Args:
        patient_id: ``space_id`` of the PatientSpace node. Falsy yields ``[]``.
        query_terms: Currently unused; the subgraph is not filtered by the
            query. Kept for interface compatibility.
        limit: Cap applied independently to each of the three triple groups.

    Returns:
        A list of ``[subject, predicate, object]`` string triples.
    """
    if not patient_id:
        return []
    # We pull the patient's connected entities and their relations to candidate diseases.
    cypher = """
MATCH (s:PatientSpace {space_id: $sid})
OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
OPTIONAL MATCH (p)<-[:HAS_PHENOTYPE]-(d:Disease)
OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis)-[:TARGETS]->(d2:Disease)
WITH s,
collect(DISTINCT {hpo: p.hpoId, name: p.name}) AS phenos,
collect(DISTINCT {orpha: d.orphaCode, name: d.name}) AS diseases,
collect(DISTINCT {orpha: d2.orphaCode, name: d2.name}) AS hypos
RETURN phenos, diseases, hypos
"""
    rows = await _safe_query(cypher, {"sid": patient_id})
    if not rows:
        return []
    row = rows[0]

    def _with_code(key: str, code_field: str) -> list:
        # OPTIONAL MATCH produces an all-null map when nothing matched; drop
        # code-less entries *before* capping so they do not use up a slot.
        return [item for item in (row.get(key) or []) if item.get(code_field)][:limit]

    triples = []
    # The collected maps always carry a `name` key (possibly null), so a
    # `.get("name", "?")` default would never apply; use `or` instead.
    for pheno in _with_code("phenos", "hpo"):
        triples.append(
            [f"patient:{patient_id}", "HAS_PHENOTYPE", f"{pheno.get('name') or '?'} ({pheno['hpo']})"]
        )
    for disease in _with_code("diseases", "orpha"):
        triples.append(
            [f"disease:{disease['orpha']}", "RELATED_TO_PATIENT", disease.get("name") or disease["orpha"]]
        )
    for hypo in _with_code("hypos", "orpha"):
        triples.append(
            ["hypothesis", "TARGETS", f"{hypo.get('name') or '?'} (ORPHA:{hypo['orpha']})"]
        )
    return triples
async def _community_summary(query_terms: list[str], limit: int = 4) -> list:
    """Group phenotypes matching the query terms by the diseases that have them.

    Bootstrap approach: for each Disease, count the distinct Phenotype nodes
    whose name contains any query term, keep the ``limit`` diseases with the
    highest counts, and use the disease name as the community label.

    Returns:
        A list of dicts with ``label``, ``members`` (up to six phenotype
        names) and a one-line ``summary``; ``[]`` when there are no terms.
    """
    if not query_terms:
        return []
    cypher = """
MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
WHERE any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
WITH d, collect(DISTINCT p.name)[..6] AS members, count(DISTINCT p) AS n
ORDER BY n DESC
LIMIT $limit
RETURN d.name AS label, members, n
"""
    params = {"terms": query_terms, "limit": int(limit)}
    records = await _safe_query(cypher, params)
    return [
        {
            "label": record.get("label"),
            "members": record.get("members") or [],
            "summary": f"{record.get('n', 0)} fenótipos compatíveis com {record.get('label')}",
        }
        for record in records
    ]
async def _cohort_exemplars(case_id: str, query_terms: list[str], limit: int = 3) -> list:
    """Fetch other patient spaces whose phenotypes match the query terms.

    The space identified by ``case_id`` is excluded; when ``case_id`` is
    falsy the placeholder ``"_"`` is sent instead.

    Returns:
        Rows with ``space_id``, ``diagnosis`` (name of the disease targeted
        by a hypothesis with status 'confirmed', if any) and ``shared_hpos``
        (up to six distinct HPO ids); ``[]`` when there are no terms.
    """
    if not query_terms:
        return []
    params = {
        "self": case_id or "_",
        "terms": query_terms,
        "limit": int(limit),
    }
    cypher = """
MATCH (other:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
WHERE other.space_id <> $self
AND any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
OPTIONAL MATCH (other)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
WITH other, d, collect(DISTINCT p.hpoId)[..6] AS shared
RETURN other.space_id AS space_id,
d.name AS diagnosis,
shared AS shared_hpos
LIMIT $limit
"""
    exemplars = await _safe_query(cypher, params)
    return exemplars
async def _literature(query_terms: list[str], limit: int = 3) -> list:
    """Fetch Paper nodes whose title contains any of the query terms.

    Results are ordered by publication year, descending, and capped at
    ``limit`` rows.

    Returns:
        Rows with ``pmid``, ``title`` and ``year``; ``[]`` when there are
        no terms.
    """
    if not query_terms:
        return []
    params = {"terms": query_terms, "limit": int(limit)}
    cypher = """
MATCH (p:Paper)
WHERE any(term IN $terms WHERE toLower(p.title) CONTAINS toLower(term))
RETURN p.pmid AS pmid, p.title AS title, p.year AS year
ORDER BY p.year DESC
LIMIT $limit
"""
    papers = await _safe_query(cypher, params)
    return papers
def _split_terms(query: str) -> list[str]:
    """Extract up to eight distinct search terms from a free-text query.

    Naive tokenisation: split on whitespace, strip surrounding punctuation,
    keep tokens of four or more characters that are not stopwords.

    Repeated tokens are dropped (compared case-insensitively, first spelling
    wins) so that duplicates do not use up the eight-term budget.

    Args:
        query: Free-text query; ``None`` or empty yields ``[]``.

    Returns:
        The kept tokens in order of first appearance, original casing.
    """
    max_terms = 8
    stop = {"the", "with", "para", "tem", "uma", "está", "como", "and", "for"}
    terms = []
    seen = set()
    for word in (query or "").split():
        token = word.strip(".,;:!?()[]\"'`")
        key = token.lower()
        if len(token) < 4 or key in stop or key in seen:
            continue
        seen.add(key)
        terms.append(token)
        if len(terms) == max_terms:
            break
    return terms
async def retrieve(
    case_id: str,
    query: str,
    *,
    k: int = DEFAULT_K,
    mode: str = "local",
) -> dict:
    """Run GraphRAG retrieval grounded in the patient.

    Patient-subgraph triples come first, followed by ``MATCHES_QUERY``
    triples for lexical hits on the full query string. Community summaries
    are always computed; cohort exemplars and literature only in global mode.

    Args:
        case_id: PatientSpace id
        query: free-text query (e.g. "is this disease linked to NPC1 mutations?")
        k: cap passed to the triple and lexical lookups; the returned
            ``triples`` list holds at most ``2 * k`` entries
        mode: "global" adds cohort exemplars and literature; any other
            value behaves as "local" (subgraph only)

    Returns:
        A dict with ``query``, ``case_id``, ``mode``, ``triples``,
        ``communities``, ``cohort_exemplars`` and ``literature``.
    """
    terms = _split_terms(query)
    triples = await _patient_triples(case_id, terms, limit=k)

    lex_hits = await _kg_lexical(query, limit=k)
    for hit in lex_hits[:k]:
        name = hit.get("name")
        if not name:
            continue
        # `kind` and `code` keys are always present in the returned rows but
        # may be null, so a `.get(key, "?")` default would render "None".
        kind = hit.get("kind") or "?"
        code = hit.get("code") or "?"
        triples.append([kind, "MATCHES_QUERY", f"{name} ({code})"])

    communities = await _community_summary(terms, limit=4)

    cohort_exemplars = []
    literature = []
    if mode == "global":
        cohort_exemplars = await _cohort_exemplars(case_id, terms, limit=3)
        literature = await _literature(terms, limit=3)

    return {
        "query": query,
        "case_id": case_id,
        "mode": mode,
        "triples": triples[: k * 2],
        "communities": communities,
        "cohort_exemplars": cohort_exemplars,
        "literature": literature,
    }
def format_for_llm(result: dict) -> str:
    """Render a retrieval result as a compact Markdown block for the LLM.

    Sections are emitted only when the corresponding key holds a non-empty
    value. At most 12 triples, 4 shared HPO ids per exemplar and 100 title
    characters per paper are rendered.

    Null optional fields (``diagnosis``, ``shared_hpos``, ``year``,
    ``title``) are rendered as placeholders instead of raising or printing
    "None".

    Args:
        result: Dict shaped like the return value of ``retrieve``.

    Returns:
        The Markdown text, or ``""`` when ``result`` is empty or ``None``.
    """
    if not result:
        return ""
    lines = [f"### Gemeo lookup — `{result.get('query', '')}`"]

    triples = result.get("triples")
    if triples:
        lines.append("**Triplas relevantes:**")
        lines.extend(f"- {t[0]} → {t[1]} → {t[2]}" for t in triples[:12])

    communities = result.get("communities")
    if communities:
        lines.append("\n**Clusters do KG:**")
        lines.extend(f"- {c['label']}: {c['summary']}" for c in communities)

    exemplars = result.get("cohort_exemplars")
    if exemplars:
        lines.append("\n**Exemplares na coorte:**")
        for exemplar in exemplars:
            dx = exemplar.get("diagnosis") or "sem dx"
            # `shared_hpos` may be present but null; treat that as empty.
            shared = ", ".join((exemplar.get("shared_hpos") or [])[:4])
            lines.append(f"- {exemplar['space_id']}: {dx} (HPO em comum: {shared})")

    papers = result.get("literature")
    if papers:
        lines.append("\n**Literatura:**")
        for paper in papers:
            year = paper.get("year")
            if year is None:
                year = "?"
            title = (paper.get("title") or "")[:100]
            lines.append(f"- PMID:{paper['pmid']} ({year}) — {title}")

    return "\n".join(lines)
|