| """GraphRAG over the patient gêmeo. |
| |
| Inspired by Microsoft GraphRAG (2024) + LightRAG (2024) + GFM-RAG (2025). |
| Adapted to the rare-disease patient setting. |
| |
| Two retrieval modes: |
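- **local** (default) — retrieve the patient subgraph (phenotypes linked to
  this twin, diseases annotated with those phenotypes, and diseases targeted
  by the twin's hypotheses), lexical matches on Disease/Phenotype/Gene/Drug
  nodes, and disease clusters for the query's phenotype terms. Fast, low-cost.
- **global** — additionally pull cohort exemplars whose phenotypes match the
  query terms and papers whose titles match them; slower, broader.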
| |
The result is a compact, JSON-serialisable dict the LLM can consume directly:
    {
      "query": "...",
      "case_id": "...",
      "mode": "local",
      "triples": [["patient:<id>", "HAS_PHENOTYPE", "Seizure (HP:0001250)"], ...],
      "communities": [{"label": "<disease name>", "members": [...], "summary": "..."}],
      "cohort_exemplars": [{"space_id": "...", "diagnosis": "...", "shared_hpos": [...]}],
      "literature": [{"pmid": "...", "title": "...", "year": ...}]
    }
"cohort_exemplars" and "literature" are empty lists in local mode.
| """ |
from __future__ import annotations

import functools
import logging
from typing import Optional
|
|
# Module logger for GraphRAG retrieval diagnostics.
logger: logging.Logger = logging.getLogger("gemeo.graphrag")

# Default retrieval cap used as the `limit` / `k` default by the helpers below.
DEFAULT_K: int = 12
|
|
|
|
async def _safe_query(
    cypher: str,
    params: Optional[dict] = None,
    *,
    timeout: float = 10.0,
) -> list:
    """Run a Cypher query through ``space_graph._safe_query`` without raising.

    The graph backend is imported lazily so this module still imports (and
    degrades to empty results) when ``space_graph`` is unavailable.

    Args:
        cypher: Cypher statement to execute.
        params: query parameters; ``None`` is sent as an empty dict.
        timeout: per-query timeout forwarded to the backend (seconds).

    Returns:
        The backend's rows, or ``[]`` when the backend cannot be imported,
        the query raises, or the backend returns nothing.
    """
    try:
        from space_graph import _safe_query as backend_query
    except ImportError as exc:
        # Treated as an optional backend: degrade quietly to no results.
        logger.debug("space_graph unavailable: %s", exc)
        return []
    try:
        rows = await backend_query(cypher, params or {}, timeout=timeout)
    except Exception as exc:  # best-effort retrieval: never break the caller
        # Real query failures/timeouts are surfaced at WARNING; at DEBUG they
        # were invisible and silently emptied every retrieval section.
        logger.warning("cypher failed: %s: %s", type(exc).__name__, exc)
        return []
    # Callers slice/index the result, so never hand them None.
    return rows or []
|
|
|
|
| def _embed_query_cached(query: str): |
| """BioLORD embedding for a free-text query. |
| |
| Returns a vector or None. Used for vector kNN over Disease/Phenotype |
| indexes when available; falls back to lexical search otherwise. |
| """ |
| try: |
| from biolord_normalizer import normalize_one |
| return normalize_one(query) |
| except Exception: |
| return None |
|
|
|
|
async def _kg_lexical(query: str, limit: int = DEFAULT_K) -> list:
    """Case-insensitive substring match over Disease/Phenotype/Gene/Drug nodes.

    Bootstrap retrieval used when no embedding is available. A node matches
    when ``n.name`` or ``n.cid10DescriptionPt`` contains ``query``.

    Args:
        query: text to look for; surrounding whitespace is ignored.
        limit: maximum number of rows.

    Returns:
        Rows with ``kind`` (first node label), ``name`` and ``code`` (first
        non-null of orphaCode / hpoId / symbol / rxcui); ``[]`` for a blank
        query.
    """
    needle = (query or "").strip()
    if not needle:
        # `x CONTAINS ''` is true for every string, so a blank query would
        # return `limit` arbitrary nodes labelled as matches.
        return []
    cypher = """
    MATCH (n)
    WHERE (n:Disease OR n:Phenotype OR n:Gene OR n:Drug)
      AND (toLower(n.name) CONTAINS toLower($q)
           OR toLower(coalesce(n.cid10DescriptionPt, '')) CONTAINS toLower($q))
    RETURN labels(n)[0] AS kind, n.name AS name,
           coalesce(n.orphaCode, n.hpoId, n.symbol, n.rxcui) AS code
    LIMIT $limit
    """
    return await _safe_query(cypher, {"q": needle, "limit": int(limit)})
|
|
|
|
def _rank_by_terms(entries: list, id_key: str, terms: list[str], limit: int) -> list:
    """Keep entries carrying ``id_key``; put query-matching names first; cap at ``limit``."""
    needles = [t.lower() for t in (terms or []) if t]
    kept = [e for e in (entries or []) if e and e.get(id_key)]

    def _misses_query(entry: dict) -> bool:
        name = (entry.get("name") or "").lower()
        return not any(n in name for n in needles)

    # sorted() is stable: matching entries move to the front, the rest keep
    # their original order. With no terms every key is equal -> order kept.
    return sorted(kept, key=_misses_query)[:limit]


async def _patient_triples(patient_id: str, query_terms: list[str], limit: int = DEFAULT_K) -> list:
    """Return triples from the patient's subgraph, query-relevant ones first.

    For the ``PatientSpace`` whose ``space_id`` is ``patient_id`` this pulls:
      * phenotypes reached via ``HAS_EVENT -> SpaceEvent -> INVOLVES``;
      * diseases annotated with any of those phenotypes (``HAS_PHENOTYPE``);
      * diseases targeted by the patient's hypotheses (``HAS_HYPOTHESIS``).

    Entries without an HPO / ORPHA code are dropped, then each kind is capped
    at ``limit``; entries whose name contains one of ``query_terms`` are kept
    ahead of the others when the cap applies.

    Returns:
        ``[subject, predicate, object]`` string triples; ``[]`` when
        ``patient_id`` is empty or the patient is not found.
    """
    if not patient_id:
        return []

    # Hypotheses are aggregated in their own stage: matching them in the same
    # stage as phenotype->disease multiplied every phenotype/disease row by the
    # number of hypotheses before collect().
    cypher = """
    MATCH (s:PatientSpace {space_id: $sid})
    OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    OPTIONAL MATCH (p)<-[:HAS_PHENOTYPE]-(d:Disease)
    WITH s,
         collect(DISTINCT {hpo: p.hpoId, name: p.name}) AS phenos,
         collect(DISTINCT {orpha: d.orphaCode, name: d.name}) AS diseases
    OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(:SpaceHypothesis)-[:TARGETS]->(d2:Disease)
    WITH s, phenos, diseases,
         collect(DISTINCT {orpha: d2.orphaCode, name: d2.name}) AS hypos
    RETURN phenos, diseases, hypos
    """
    rows = await _safe_query(cypher, {"sid": patient_id})
    if not rows:
        return []
    row = rows[0]
    subject = f"patient:{patient_id}"

    # `name` is always present in the collected maps (possibly null), so use
    # `or` rather than a .get() default to avoid rendering "None".
    triples = [
        [subject, "HAS_PHENOTYPE", f"{p.get('name') or '?'} ({p['hpo']})"]
        for p in _rank_by_terms(row.get("phenos"), "hpo", query_terms, limit)
    ]
    triples += [
        [f"disease:{d['orpha']}", "RELATED_TO_PATIENT", d.get("name") or d["orpha"]]
        for d in _rank_by_terms(row.get("diseases"), "orpha", query_terms, limit)
    ]
    triples += [
        ["hypothesis", "TARGETS", f"{h.get('name') or '?'} (ORPHA:{h['orpha']})"]
        for h in _rank_by_terms(row.get("hypos"), "orpha", query_terms, limit)
    ]
    return triples
|
|
|
|
| async def _community_summary(query_terms: list[str], limit: int = 4) -> list: |
| """Cluster matched nodes into "communities" by their primary disease group. |
| |
| Bootstrap version: groups Phenotype matches by their most-connected |
| Disease, returns the disease name as the community label. |
| """ |
| if not query_terms: |
| return [] |
| cypher = """ |
| MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease) |
| WHERE any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term)) |
| WITH d, collect(DISTINCT p.name)[..6] AS members, count(DISTINCT p) AS n |
| ORDER BY n DESC |
| LIMIT $limit |
| RETURN d.name AS label, members, n |
| """ |
| rows = await _safe_query(cypher, {"terms": query_terms, "limit": int(limit)}) |
| out = [] |
| for r in rows: |
| out.append({ |
| "label": r.get("label"), |
| "members": r.get("members") or [], |
| "summary": f"{r.get('n', 0)} fenótipos compatíveis com {r.get('label')}", |
| }) |
| return out |
|
|
|
|
| async def _cohort_exemplars(case_id: str, query_terms: list[str], limit: int = 3) -> list: |
| """Pull a few cohort cases whose phenotypes overlap with the query terms.""" |
| if not query_terms: |
| return [] |
| cypher = """ |
| MATCH (other:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype) |
| WHERE other.space_id <> $self |
| AND any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term)) |
| OPTIONAL MATCH (other)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease) |
| WITH other, d, collect(DISTINCT p.hpoId)[..6] AS shared |
| RETURN other.space_id AS space_id, |
| d.name AS diagnosis, |
| shared AS shared_hpos |
| LIMIT $limit |
| """ |
| return await _safe_query(cypher, { |
| "self": case_id or "_", |
| "terms": query_terms, |
| "limit": int(limit), |
| }) |
|
|
|
|
| async def _literature(query_terms: list[str], limit: int = 3) -> list: |
| """Top PubMed case reports overlapping the query terms.""" |
| if not query_terms: |
| return [] |
| cypher = """ |
| MATCH (p:Paper) |
| WHERE any(term IN $terms WHERE toLower(p.title) CONTAINS toLower(term)) |
| RETURN p.pmid AS pmid, p.title AS title, p.year AS year |
| ORDER BY p.year DESC |
| LIMIT $limit |
| """ |
| return await _safe_query(cypher, {"terms": query_terms, "limit": int(limit)}) |
|
|
|
|
| def _split_terms(query: str) -> list[str]: |
| |
| stop = {"the", "with", "para", "tem", "uma", "está", "como", "and", "for"} |
| return [t for t in (w.strip(".,;:!?()[]\"'`") for w in (query or "").split()) |
| if len(t) >= 4 and t.lower() not in stop][:8] |
|
|
|
|
async def retrieve(
    case_id: str,
    query: str,
    *,
    k: int = DEFAULT_K,
    mode: str = "local",
) -> dict:
    """Run GraphRAG retrieval grounded in the patient.

    Args:
        case_id: PatientSpace id (``space_id``) of the patient.
        query: free-text query (e.g. "is this disease linked to NPC1 mutations?").
        k: per-source cap. Each kind of patient triple and the lexical hits are
            limited to ``k``, and the combined ``triples`` list to ``2 * k``.
        mode: "local" (patient subgraph + lexical hits + communities) or
            "global" (also cohort exemplars and literature). Any other value
            behaves like "local".

    Returns:
        Dict with ``query``, ``case_id``, ``mode``, ``triples``,
        ``communities``, ``cohort_exemplars`` and ``literature``; the last two
        are empty unless ``mode == "global"``.
    """
    terms = _split_terms(query)

    # Patient subgraph first, then KG-wide lexical hits for the raw query.
    # NOTE(review): the whole query is matched as a single substring, so
    # sentence-style questions rarely produce lexical hits — confirm whether
    # matching on `terms` was intended.
    triples = await _patient_triples(case_id, terms, limit=k)
    lex_hits = await _kg_lexical(query, limit=k)
    for hit in lex_hits[:k]:
        name = hit.get("name")
        if not name:
            continue
        # `code` is null when a node has none of the coalesced id properties;
        # `.get("code", "?")` would then render "(None)".
        kind = hit.get("kind") or "?"
        code = hit.get("code") or "?"
        triples.append([kind, "MATCHES_QUERY", f"{name} ({code})"])

    communities = await _community_summary(terms, limit=4)

    # NOTE(review): the queries above and below are independent and awaited
    # sequentially; they could be gathered concurrently if space_graph allows
    # concurrent queries — confirm before changing.
    cohort_exemplars: list = []
    literature: list = []
    if mode == "global":
        cohort_exemplars = await _cohort_exemplars(case_id, terms, limit=3)
        literature = await _literature(terms, limit=3)

    return {
        "query": query,
        "case_id": case_id,
        "mode": mode,
        "triples": triples[: k * 2],
        "communities": communities,
        "cohort_exemplars": cohort_exemplars,
        "literature": literature,
    }
|
|
|
|
def format_for_llm(result: dict, *, max_triples: int = 12) -> str:
    """Render a ``retrieve`` result as a compact Markdown block for an LLM prompt.

    Args:
        result: dict shaped like the return value of ``retrieve``; missing or
            empty sections are skipped.
        max_triples: how many triples to list (default 12).

    Returns:
        Markdown text (section headings in Portuguese), or ``""`` for an
        empty/None ``result``.
    """
    if not result:
        return ""
    lines = [f"### Gemeo lookup — `{result.get('query', '')}`"]

    triples = result.get("triples")
    if triples:
        lines.append("**Triplas relevantes:**")
        lines.extend(f"- {t[0]} → {t[1]} → {t[2]}" for t in triples[:max_triples])

    communities = result.get("communities")
    if communities:
        lines.append("\n**Clusters do KG:**")
        lines.extend(f"- {c.get('label') or '?'}: {c['summary']}" for c in communities)

    exemplars = result.get("cohort_exemplars")
    if exemplars:
        lines.append("\n**Exemplares na coorte:**")
        for exemplar in exemplars:
            dx = exemplar.get("diagnosis") or "sem dx"
            # `shared_hpos` may be present but null; render it as empty.
            hpos = ", ".join(str(h) for h in (exemplar.get("shared_hpos") or [])[:4])
            lines.append(f"- {exemplar['space_id']}: {dx} (HPO em comum: {hpos})")

    papers = result.get("literature")
    if papers:
        lines.append("\n**Literatura:**")
        for paper in papers:
            # Graph rows carry explicit nulls, so `.get(key, default)` would
            # print "None" (year) or crash on slicing (title).
            year = paper.get("year") or "?"
            title = (paper.get("title") or "")[:100]
            lines.append(f"- PMID:{paper['pmid']} ({year}) — {title}")
    return "\n".join(lines)
|
|