"""GraphRAG over the patient twin (gêmeo).

Inspired by Microsoft GraphRAG (2024) + LightRAG (2024) + GFM-RAG (2025).
Adapted to the rare-disease patient setting.

Two retrieval modes:

- **local** (default) — retrieve from the patient subgraph (HPO + Disease +
  Gene + Drug nodes already linked to this twin). Fast, low-cost.
- **global** — also pull cohort exemplars and PubMed cases that share
  matching nodes; slower, broader.

The result is a compact JSON object the LLM can consume directly:

    {
      "query": "...",
      "case_id": "...",
      "mode": "local",
      "triples": [
        ["patient:<id>", "HAS_PHENOTYPE", "Seizure (HP:0001250)"], ...
      ],
      "communities": [
        {"label": "...", "members": [...], "summary": "..."}
      ],
      "cohort_exemplars": [
        {"space_id": "...", "diagnosis": "...", "shared_hpos": [...]}
      ],
      "literature": [{"pmid": "...", "title": "...", "year": ...}]
    }
"""
from __future__ import annotations

import logging
from typing import Optional

logger = logging.getLogger("gemeo.graphrag")

# Default cap on the number of rows/triples pulled per retrieval step.
DEFAULT_K = 12

# Tokens dropped by ``_split_terms``. Entries shorter than four characters are
# already excluded by the length filter there; they are kept for reference.
_STOPWORDS = {"the", "with", "para", "tem", "uma", "está", "como", "and", "for"}

# Characters stripped from both ends of each whitespace-separated token.
_TOKEN_PUNCTUATION = ".,;:!?()[]\"'`"


async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query, returning ``[]`` instead of raising on failure.

    Delegates to ``space_graph._safe_query`` with a 10-second timeout. The
    import happens inside the ``try`` so that an unavailable ``space_graph``
    module also degrades to an empty result.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is sent as an empty dict.

    Returns:
        Whatever the delegate returns, or ``[]`` if anything raised.
    """
    try:
        from space_graph import _safe_query as run_query

        return await run_query(cypher, params or {}, timeout=10.0)
    except Exception as exc:  # deliberate best-effort: retrieval must not crash
        logger.debug("cypher failed: %s", exc)
        return []


def _embed_query_cached(query: str):
    """Return a BioLORD embedding for a free-text query, or ``None``.

    Intended for vector kNN over Disease/Phenotype indexes when available,
    with lexical search as the fallback otherwise.

    NOTE(review): despite the name, this function performs no caching itself;
    any caching would have to live inside ``biolord_normalizer.normalize_one``.
    It is not called anywhere else in this module.

    Returns:
        The value of ``normalize_one(query)``, or ``None`` if the import or
        the call raised.
    """
    try:
        from biolord_normalizer import normalize_one

        return normalize_one(query)
    except Exception:
        return None


async def _kg_lexical(query: str, limit: int = DEFAULT_K) -> list:
    """Lexical match (case-insensitive contains) — bootstrap when no embedding.

    Searches Disease, Phenotype, Gene and Drug nodes whose ``name`` (or
    ``cid10DescriptionPt``) contains the query text.

    Args:
        query: Free-text query; surrounding whitespace is ignored.
        limit: Maximum number of rows to return.

    Returns:
        Rows with ``kind``, ``name`` and ``code`` keys, or ``[]`` for a blank
        query or a failed lookup.
    """
    needle = (query or "").strip()
    if not needle:
        # ``CONTAINS ''`` is true for every non-null string, so a blank query
        # would otherwise return arbitrary nodes as "matches".
        return []
    cypher = """
    MATCH (n)
    WHERE (n:Disease OR n:Phenotype OR n:Gene OR n:Drug)
      AND (toLower(n.name) CONTAINS toLower($q)
           OR toLower(coalesce(n.cid10DescriptionPt, '')) CONTAINS toLower($q))
    RETURN labels(n)[0] AS kind, n.name AS name,
           coalesce(n.orphaCode, n.hpoId, n.symbol, n.rxcui) AS code
    LIMIT $limit
    """
    return await _safe_query(cypher, {"q": needle, "limit": int(limit)})


async def _patient_triples(
    patient_id: str,
    query_terms: list[str],
    limit: int = DEFAULT_K,
) -> list:
    """Pull triples from the patient's subgraph.

    Args:
        patient_id: ``space_id`` of the PatientSpace node.
        query_terms: Currently unused by the query; kept in the signature.
        limit: Cap applied separately to phenotypes, diseases and hypotheses.

    Returns:
        A list of ``[subject, relation, object]`` string triples; empty when
        ``patient_id`` is falsy or the query returns no rows.
    """
    if not patient_id:
        return []
    # We pull the patient's connected entities and their relations to
    # candidate diseases.
    cypher = """
    MATCH (s:PatientSpace {space_id: $sid})
    OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    OPTIONAL MATCH (p)<-[:HAS_PHENOTYPE]-(d:Disease)
    OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis)-[:TARGETS]->(d2:Disease)
    WITH s,
         collect(DISTINCT {hpo: p.hpoId, name: p.name}) AS phenos,
         collect(DISTINCT {orpha: d.orphaCode, name: d.name}) AS diseases,
         collect(DISTINCT {orpha: d2.orphaCode, name: d2.name}) AS hypos
    RETURN phenos, diseases, hypos
    """
    rows = await _safe_query(cypher, {"sid": patient_id})
    if not rows:
        return []
    row = rows[0]

    triples = []
    # The collected maps always carry every key (possibly null), so a
    # ``.get(key, default)`` default never applies; use ``or`` for fallbacks.
    for pheno in (row.get("phenos") or [])[:limit]:
        if pheno.get("hpo"):
            name = pheno.get("name") or "?"
            triples.append(
                [f"patient:{patient_id}", "HAS_PHENOTYPE", f"{name} ({pheno['hpo']})"]
            )
    for disease in (row.get("diseases") or [])[:limit]:
        if disease.get("orpha"):
            triples.append(
                [
                    f"disease:{disease['orpha']}",
                    "RELATED_TO_PATIENT",
                    disease.get("name") or disease["orpha"],
                ]
            )
    for hypo in (row.get("hypos") or [])[:limit]:
        if hypo.get("orpha"):
            name = hypo.get("name") or "?"
            triples.append(
                ["hypothesis", "TARGETS", f"{name} (ORPHA:{hypo['orpha']})"]
            )
    return triples


async def _community_summary(query_terms: list[str], limit: int = 4) -> list:
    """Cluster matched phenotypes into "communities" keyed by disease.

    Bootstrap version: finds Phenotype nodes whose name contains any query
    term, groups them by each linked Disease, and keeps the diseases with the
    most matching phenotypes. The disease name is the community label.

    Args:
        query_terms: Terms to match against phenotype names.
        limit: Maximum number of communities.

    Returns:
        Dicts with ``label``, ``members`` (up to six phenotype names) and a
        Portuguese ``summary`` string; ``[]`` when there are no terms.
    """
    if not query_terms:
        return []
    cypher = """
    MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
    WHERE any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
    WITH d, collect(DISTINCT p.name)[..6] AS members, count(DISTINCT p) AS n
    ORDER BY n DESC
    LIMIT $limit
    RETURN d.name AS label, members, n
    """
    rows = await _safe_query(cypher, {"terms": query_terms, "limit": int(limit)})
    return [
        {
            "label": r.get("label"),
            "members": r.get("members") or [],
            "summary": f"{r.get('n', 0)} fenótipos compatíveis com {r.get('label')}",
        }
        for r in rows
    ]


async def _cohort_exemplars(
    case_id: str,
    query_terms: list[str],
    limit: int = 3,
) -> list:
    """Pull a few cohort cases whose phenotypes overlap with the query terms.

    Args:
        case_id: ``space_id`` of the current patient, excluded from results.
            A falsy value is replaced by the placeholder ``"_"``.
        query_terms: Terms to match against phenotype names.
        limit: Maximum number of rows.

    Returns:
        Rows with ``space_id``, ``diagnosis`` (name of a confirmed hypothesis'
        disease, possibly null) and ``shared_hpos``; ``[]`` without terms.
    """
    if not query_terms:
        return []
    cypher = """
    MATCH (other:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    WHERE other.space_id <> $self
      AND any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
    OPTIONAL MATCH (other)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
    WITH other, d, collect(DISTINCT p.hpoId)[..6] AS shared
    RETURN other.space_id AS space_id, d.name AS diagnosis, shared AS shared_hpos
    LIMIT $limit
    """
    return await _safe_query(
        cypher,
        {"self": case_id or "_", "terms": query_terms, "limit": int(limit)},
    )


async def _literature(query_terms: list[str], limit: int = 3) -> list:
    """Return the newest Paper nodes whose title contains any query term.

    Args:
        query_terms: Terms to match against paper titles.
        limit: Maximum number of rows.

    Returns:
        Rows with ``pmid``, ``title`` and ``year`` ordered by year descending;
        ``[]`` without terms.
    """
    if not query_terms:
        return []
    cypher = """
    MATCH (p:Paper)
    WHERE any(term IN $terms WHERE toLower(p.title) CONTAINS toLower(term))
    RETURN p.pmid AS pmid, p.title AS title, p.year AS year
    ORDER BY p.year DESC
    LIMIT $limit
    """
    return await _safe_query(cypher, {"terms": query_terms, "limit": int(limit)})


def _split_terms(query: str) -> list[str]:
    """Split a free-text query into at most eight search terms.

    Naive tokenizer: splits on whitespace, strips surrounding punctuation,
    keeps tokens of four or more characters that are not stopwords. Original
    casing and order are preserved; duplicates are not removed.
    """
    terms = []
    for word in (query or "").split():
        token = word.strip(_TOKEN_PUNCTUATION)
        if len(token) >= 4 and token.lower() not in _STOPWORDS:
            terms.append(token)
    return terms[:8]


async def retrieve(
    case_id: str,
    query: str,
    *,
    k: int = DEFAULT_K,
    mode: str = "local",
) -> dict:
    """Run GraphRAG retrieval grounded in the patient.

    Args:
        case_id: PatientSpace id
        query: free-text query (e.g. "is this disease linked to NPC1
            mutations?")
        k: cap for each triple source; the combined list is cut to ``2 * k``
        mode: "local" (subgraph only) or "global" (subgraph + cohort +
            literature). Any value other than "global" behaves like "local".

    Returns:
        A dict with ``query``, ``case_id``, ``mode``, ``triples``,
        ``communities``, ``cohort_exemplars`` and ``literature`` keys. The
        last two are empty lists unless ``mode == "global"``.
    """
    terms = _split_terms(query)

    triples = await _patient_triples(case_id, terms, limit=k)
    lexical_hits = await _kg_lexical(query, limit=k)
    for hit in lexical_hits[:k]:
        if hit.get("name"):
            # Rows carry ``kind``/``code`` keys even when the value is null,
            # so fall back with ``or`` rather than a ``.get`` default.
            kind = hit.get("kind") or "?"
            code = hit.get("code") or "?"
            triples.append([kind, "MATCHES_QUERY", f"{hit['name']} ({code})"])

    communities = await _community_summary(terms, limit=4)

    cohort_exemplars = []
    literature = []
    if mode == "global":
        cohort_exemplars = await _cohort_exemplars(case_id, terms, limit=3)
        literature = await _literature(terms, limit=3)

    return {
        "query": query,
        "case_id": case_id,
        "mode": mode,
        "triples": triples[: k * 2],
        "communities": communities,
        "cohort_exemplars": cohort_exemplars,
        "literature": literature,
    }


def format_for_llm(result: dict) -> str:
    """Render a retrieval result as a compact Markdown block for the LLM.

    Sections (triples, communities, cohort exemplars, literature) are emitted
    only when the corresponding list is non-empty. Null or missing optional
    fields are rendered with placeholders instead of raising.

    Args:
        result: A dict shaped like the return value of ``retrieve``.

    Returns:
        The Markdown text, or ``""`` when ``result`` is falsy.
    """
    if not result:
        return ""
    lines = [f"### Gemeo lookup — `{result.get('query', '')}`"]

    triples = result.get("triples")
    if triples:
        lines.append("**Triplas relevantes:**")
        # Show at most 12 triples to keep the prompt compact.
        lines.extend(f"- {t[0]} → {t[1]} → {t[2]}" for t in triples[:12])

    communities = result.get("communities")
    if communities:
        lines.append("\n**Clusters do KG:**")
        for community in communities:
            label = community.get("label") or "?"
            lines.append(f"- {label}: {community.get('summary') or ''}")

    exemplars = result.get("cohort_exemplars")
    if exemplars:
        lines.append("\n**Exemplares na coorte:**")
        for exemplar in exemplars:
            dx = exemplar.get("diagnosis") or "sem dx"
            # ``shared_hpos`` may be present but null; treat that as empty.
            shared = ", ".join(
                str(hpo) for hpo in (exemplar.get("shared_hpos") or [])[:4]
            )
            space_id = exemplar.get("space_id") or "?"
            lines.append(f"- {space_id}: {dx} (HPO em comum: {shared})")

    papers = result.get("literature")
    if papers:
        lines.append("\n**Literatura:**")
        for paper in papers:
            pmid = paper.get("pmid") or "?"
            year = paper.get("year") or "?"
            # A null title must not break slicing; titles are cut at 100 chars.
            title = (paper.get("title") or "")[:100]
            lines.append(f"- PMID:{pmid} ({year}) — {title}")

    return "\n".join(lines)