gemeo-twin-stack / src /gemeo /graphrag.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""GraphRAG over the patient gêmeo.
Inspired by Microsoft GraphRAG (2024) + LightRAG (2024) + GFM-RAG (2025).
Adapted to the rare-disease patient setting.
Two retrieval modes:
- **local** (default) — retrieve from the patient subgraph (HPO + Disease
+ Gene + Drug nodes already linked to this twin). Fast, low-cost.
- **global** — also pull cohort exemplars and PubMed cases that share
matching nodes; slower, broader.
The result of ``retrieve`` is a compact, JSON-serializable dict the LLM can consume directly:
{
"query": "...",
"case_id": "...",
"mode": "local",
"triples": [["patient:<case_id>", "HAS_PHENOTYPE", "Seizure (HP:0001250)"], ...],
"communities": [{"label": "<disease name>", "members": [...], "summary": "..."}],
"cohort_exemplars": [{"space_id": "...", "diagnosis": "...", "shared_hpos": [...]}],
"literature": [{"pmid": "...", "title": "...", "year": ...}]
}
``cohort_exemplars`` and ``literature`` are empty lists unless ``mode`` is "global".
"""
from __future__ import annotations
import logging
from typing import Optional
# Module-level logger; registered under a fixed name rather than __name__.
logger = logging.getLogger("gemeo.graphrag")
# Default row/triple cap: the default `limit` of _kg_lexical and
# _patient_triples, and the default `k` of retrieve().
DEFAULT_K = 12
async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query through ``space_graph._safe_query``, best-effort.

    The backend is imported lazily inside the ``try`` block, so a missing or
    broken ``space_graph`` module degrades to "no rows" instead of raising.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is forwarded as an empty dict.

    Returns:
        Whatever the backend returns (the callers in this module treat it as
        a list of dict-like rows — TODO confirm against ``space_graph``), or
        ``[]`` when the import or the query raises.
    """
    try:
        from space_graph import _safe_query as run_query

        return await run_query(cypher, params or {}, timeout=10.0)
    except Exception as exc:
        # Deliberate best-effort: retrieval must not fail the caller, so the
        # error is only logged. Note that a failed query is therefore
        # indistinguishable from an empty result for callers.
        logger.debug("cypher failed: %s", exc)
        return []
def _embed_query_cached(query: str):
    """Embed a free-text query via ``biolord_normalizer.normalize_one``.

    The normalizer is imported on demand. Any exception — a missing module or
    a failure inside the call — yields ``None``, which a caller can use as the
    cue to fall back to lexical search. Despite the name, this function does
    no caching of its own.

    Returns:
        Whatever ``normalize_one`` returns (presumably an embedding vector —
        TODO confirm), or ``None`` on failure.
    """
    embedding = None
    try:
        from biolord_normalizer import normalize_one as _normalize

        embedding = _normalize(query)
    except Exception:
        # Best-effort: embedding support is optional.
        embedding = None
    return embedding
async def _kg_lexical(query: str, limit: int = DEFAULT_K) -> list:
    """Case-insensitive substring match of ``query`` against KG node names.

    Searches Disease, Phenotype, Gene and Drug nodes by ``name`` (and by
    ``cid10DescriptionPt`` when present). This is the bootstrap path used
    when no embedding is available.

    Args:
        query: Free-text needle. Surrounding whitespace is ignored.
        limit: Maximum number of rows requested from the graph.

    Returns:
        Rows with ``kind``, ``name`` and ``code`` keys, or ``[]`` when the
        query is empty/blank or the lookup fails.
    """
    needle = (query or "").strip()
    if not needle:
        # An empty needle makes every CONTAINS test true, which would return
        # `limit` arbitrary nodes labelled as "matches".
        return []
    cypher = """
MATCH (n)
WHERE (n:Disease OR n:Phenotype OR n:Gene OR n:Drug)
AND (toLower(n.name) CONTAINS toLower($q)
OR toLower(coalesce(n.cid10DescriptionPt, '')) CONTAINS toLower($q))
RETURN labels(n)[0] AS kind, n.name AS name,
coalesce(n.orphaCode, n.hpoId, n.symbol, n.rxcui) AS code
LIMIT $limit
"""
    return await _safe_query(cypher, {"q": needle, "limit": int(limit)})
async def _patient_triples(patient_id: str, query_terms: list[str], limit: int = DEFAULT_K) -> list:
    """Build ``[subject, predicate, object]`` triples from a patient's subgraph.

    A single Cypher query collects, for the ``PatientSpace`` whose
    ``space_id`` equals ``patient_id``: the phenotypes reached through its
    events, the diseases linked to those phenotypes, and the diseases
    targeted by its hypotheses.

    Args:
        patient_id: ``space_id`` of the PatientSpace; a falsy value
            short-circuits to ``[]``.
        query_terms: Accepted for symmetry with the other helpers; it is
            currently not used to filter the triples.
        limit: Maximum number of triples emitted per category (phenotypes,
            diseases, hypotheses), i.e. up to ``3 * limit`` overall.

    Returns:
        A list of 3-item lists, or ``[]`` when the query yields no rows.
    """
    if not patient_id:
        return []
    # We pull the patient's connected entities and their relations to candidate diseases.
    cypher = """
MATCH (s:PatientSpace {space_id: $sid})
OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
OPTIONAL MATCH (p)<-[:HAS_PHENOTYPE]-(d:Disease)
OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis)-[:TARGETS]->(d2:Disease)
WITH s,
collect(DISTINCT {hpo: p.hpoId, name: p.name}) AS phenos,
collect(DISTINCT {orpha: d.orphaCode, name: d.name}) AS diseases,
collect(DISTINCT {orpha: d2.orphaCode, name: d2.name}) AS hypos
RETURN phenos, diseases, hypos
"""
    rows = await _safe_query(cypher, {"sid": patient_id})
    if not rows:
        return []
    row = rows[0]

    def _usable(field: str, key: str) -> list:
        # Drop entries whose identifying key is null/empty (e.g. produced by
        # an unmatched OPTIONAL MATCH) *before* capping, so such placeholders
        # do not use up one of the `limit` slots.
        return [item for item in (row.get(field) or []) if item.get(key)][:limit]

    triples = []
    # `get(...) or fallback` is used for names because a key that is present
    # with a null value would otherwise be rendered as the text "None".
    for pheno in _usable("phenos", "hpo"):
        triples.append(
            [f"patient:{patient_id}", "HAS_PHENOTYPE", f"{pheno.get('name') or '?'} ({pheno['hpo']})"]
        )
    for disease in _usable("diseases", "orpha"):
        triples.append(
            [f"disease:{disease['orpha']}", "RELATED_TO_PATIENT", disease.get("name") or disease["orpha"]]
        )
    for hypo in _usable("hypos", "orpha"):
        triples.append(
            ["hypothesis", "TARGETS", f"{hypo.get('name') or '?'} (ORPHA:{hypo['orpha']})"]
        )
    return triples
async def _community_summary(query_terms: list[str], limit: int = 4) -> list:
    """Group phenotype matches into per-disease "communities".

    Bootstrap version: every Disease that has at least one Phenotype whose
    name contains a query term (case-insensitive) becomes a community; the
    diseases with the most matching phenotypes come first.

    Args:
        query_terms: Substrings to look for in phenotype names; an empty
            value short-circuits to ``[]``.
        limit: Maximum number of communities requested.

    Returns:
        Dicts with ``label`` (disease name), ``members`` (up to six matching
        phenotype names) and ``summary`` (a short Portuguese sentence).
    """
    if not query_terms:
        return []
    cypher = """
MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
WHERE any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
WITH d, collect(DISTINCT p.name)[..6] AS members, count(DISTINCT p) AS n
ORDER BY n DESC
LIMIT $limit
RETURN d.name AS label, members, n
"""
    params = {"terms": query_terms, "limit": int(limit)}
    rows = await _safe_query(cypher, params)
    return [
        {
            "label": row.get("label"),
            "members": row.get("members") or [],
            "summary": f"{row.get('n', 0)} fenótipos compatíveis com {row.get('label')}",
        }
        for row in rows
    ]
async def _cohort_exemplars(case_id: str, query_terms: list[str], limit: int = 3) -> list:
    """Fetch a few other PatientSpaces whose phenotypes overlap the query terms.

    Args:
        case_id: ``space_id`` of the current patient, excluded from the
            results; a falsy value is replaced by the placeholder ``"_"``.
        query_terms: Substrings matched (case-insensitive) against phenotype
            names; an empty value short-circuits to ``[]``.
        limit: Maximum number of rows requested.

    Returns:
        Rows with ``space_id``, ``diagnosis`` (name of a disease targeted by
        a hypothesis with status 'confirmed', if any) and ``shared_hpos`` (up
        to six HPO ids).
    """
    if not query_terms:
        return []
    params = {
        "self": case_id or "_",
        "terms": query_terms,
        "limit": int(limit),
    }
    cypher = """
MATCH (other:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
WHERE other.space_id <> $self
AND any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
OPTIONAL MATCH (other)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
WITH other, d, collect(DISTINCT p.hpoId)[..6] AS shared
RETURN other.space_id AS space_id,
d.name AS diagnosis,
shared AS shared_hpos
LIMIT $limit
"""
    exemplars = await _safe_query(cypher, params)
    return exemplars
async def _literature(query_terms: list[str], limit: int = 3) -> list:
    """Fetch Paper nodes whose title contains any query term, newest first.

    Args:
        query_terms: Substrings matched (case-insensitive) against paper
            titles; an empty value short-circuits to ``[]``.
        limit: Maximum number of rows requested.

    Returns:
        Rows with ``pmid``, ``title`` and ``year`` keys.
    """
    if not query_terms:
        return []
    params = {"terms": query_terms, "limit": int(limit)}
    cypher = """
MATCH (p:Paper)
WHERE any(term IN $terms WHERE toLower(p.title) CONTAINS toLower(term))
RETURN p.pmid AS pmid, p.title AS title, p.year AS year
ORDER BY p.year DESC
LIMIT $limit
"""
    papers = await _safe_query(cypher, params)
    return papers
def _split_terms(query: str) -> list[str]:
    """Split a free-text query into at most eight search terms.

    Naive tokenizer: splits on whitespace, strips surrounding punctuation,
    keeps tokens of four or more characters that are not stopwords, and
    drops case-insensitive duplicates so repeated words do not use up the
    eight-term budget.

    Args:
        query: Free-text query; ``None`` or ``""`` yields ``[]``.

    Returns:
        Terms in order of first appearance, with their original casing.
    """
    stop = {"the", "with", "para", "tem", "uma", "está", "como", "and", "for"}
    terms = []
    seen = set()
    for word in (query or "").split():
        token = word.strip(".,;:!?()[]\"'`")
        key = token.lower()
        if len(token) < 4 or key in stop or key in seen:
            continue
        seen.add(key)
        terms.append(token)
        if len(terms) == 8:
            break
    return terms
async def retrieve(
    case_id: str,
    query: str,
    *,
    k: int = DEFAULT_K,
    mode: str = "local",
) -> dict:
    """Run GraphRAG retrieval grounded in the patient.

    Args:
        case_id: PatientSpace id.
        query: Free-text query (e.g. "is this disease linked to NPC1 mutations?").
        k: Row cap handed to the patient-subgraph and lexical lookups; the
            combined ``triples`` list is truncated to ``2 * k`` entries.
        mode: "global" additionally fetches cohort exemplars and literature;
            any other value behaves like "local".

    Returns:
        A dict with ``query``, ``case_id``, ``mode``, ``triples``,
        ``communities``, ``cohort_exemplars`` and ``literature`` keys. The
        last two are empty lists unless ``mode == "global"``.
    """
    terms = _split_terms(query)
    triples = await _patient_triples(case_id, terms, limit=k)

    lex_hits = await _kg_lexical(query, limit=k)
    for hit in lex_hits[:k]:
        name = hit.get("name")
        if not name:
            continue
        # `or "?"` rather than a .get() default: the keys are present with a
        # null value when the graph has no label/code, and would otherwise
        # be rendered as the text "None".
        kind = hit.get("kind") or "?"
        code = hit.get("code") or "?"
        triples.append([kind, "MATCHES_QUERY", f"{name} ({code})"])

    communities = await _community_summary(terms, limit=4)

    cohort_exemplars = []
    literature = []
    if mode == "global":
        cohort_exemplars = await _cohort_exemplars(case_id, terms, limit=3)
        literature = await _literature(terms, limit=3)

    return {
        "query": query,
        "case_id": case_id,
        "mode": mode,
        "triples": triples[: k * 2],
        "communities": communities,
        "cohort_exemplars": cohort_exemplars,
        "literature": literature,
    }
def format_for_llm(result: dict) -> str:
    """Render a retrieval result as a compact Markdown block for the LLM.

    Sections are emitted only for non-empty keys, in this order: triples (at
    most 12, written as ``subject -[PREDICATE]-> object``), communities,
    cohort exemplars, literature. Section headings are in Portuguese.

    Args:
        result: Dict shaped like the return value of ``retrieve``; a falsy
            value yields ``""``.

    Returns:
        The Markdown text, lines joined with newlines.
    """
    if not result:
        return ""
    lines = [f"### Gemeo lookup — `{result.get('query', '')}`"]

    triples = result.get("triples")
    if triples:
        lines.append("**Triplas relevantes:**")
        for triple in triples[:12]:
            # Explicit separators: plain concatenation ran subject, predicate
            # and object together into one unreadable token.
            lines.append(f"- {triple[0]} -[{triple[1]}]-> {triple[2]}")

    communities = result.get("communities")
    if communities:
        lines.append("\n**Clusters do KG:**")
        for community in communities:
            lines.append(f"- {community['label']}: {community['summary']}")

    exemplars = result.get("cohort_exemplars")
    if exemplars:
        lines.append("\n**Exemplares na coorte:**")
        for exemplar in exemplars:
            dx = exemplar.get("diagnosis") or "sem dx"
            # The key may be present with a null value; treat that as "none".
            shared = ", ".join(str(h) for h in (exemplar.get("shared_hpos") or [])[:4])
            lines.append(f"- {exemplar['space_id']}: {dx} (HPO em comum: {shared})")

    papers = result.get("literature")
    if papers:
        lines.append("\n**Literatura:**")
        for paper in papers:
            # Null year/title must not crash the slice or print "None".
            year = paper.get("year") or "?"
            title = (paper.get("title") or "")[:100]
            lines.append(f"- PMID:{paper['pmid']} ({year}) — {title}")

    return "\n".join(lines)