File size: 8,783 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""GraphRAG over the patient gêmeo.

Inspired by Microsoft GraphRAG (2024) + LightRAG (2024) + GFM-RAG (2025).
Adapted to the rare-disease patient setting.

Two retrieval modes:
  - **local** (default) — retrieve from the patient subgraph (HPO + Disease
    + Gene + Drug nodes already linked to this twin). Fast, low-cost.
  - **global** — also pull cohort exemplars and PubMed cases that share
    matching nodes; slower, broader.

The result is a compact JSON object the LLM can consume directly
(``cohort_exemplars`` and ``literature`` are empty lists in local mode):
  {
    "query": "...",
    "case_id": "...",
    "mode": "local",
    "triples": [["patient:<id>", "HAS_PHENOTYPE", "Seizure (HP:0001250)"], ...],
    "communities": [{"label": "<disease name>", "members": [...], "summary": "..."}],
    "cohort_exemplars": [{"space_id": "...", "diagnosis": "...", "shared_hpos": [...]}],
    "literature": [{"pmid": "...", "title": "...", "year": ...}]
  }
"""
from __future__ import annotations
import logging
from typing import Optional

# Module-level logger; the name is a fixed string rather than __name__.
logger = logging.getLogger("gemeo.graphrag")

# Default cap: used as the default `limit` of _kg_lexical/_patient_triples
# and as the default `k` of retrieve().
DEFAULT_K = 12


async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query via ``space_graph._safe_query`` without ever raising.

    This is a best-effort wrapper: any exception (including ``space_graph``
    not being importable) is logged at DEBUG level and reported to the caller
    as an empty result.

    Args:
        cypher: Cypher statement to execute.
        params: Query parameters; ``None`` is forwarded as an empty dict.

    Returns:
        Whatever the underlying helper returns (callers in this module treat
        it as a list of dict-like rows), or ``[]`` on any failure.
    """
    try:
        # Imported inside the try so an import failure is also downgraded
        # to an empty result instead of propagating.
        from space_graph import _safe_query as q
        return await q(cypher, params or {}, timeout=10.0)
    except Exception as e:
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("cypher failed: %s", e)
        return []


def _embed_query_cached(query: str):
    """Embed a free-text query by delegating to ``biolord_normalizer``.

    Calls ``biolord_normalizer.normalize_one(query)`` and hands back its
    result unchanged. If the module cannot be imported or the call raises,
    the failure is swallowed and ``None`` is returned.

    NOTE(review): despite the name, no caching happens in this function, and
    no caller is visible in this part of the module. The original note says
    it is meant for vector kNN over Disease/Phenotype indexes, with lexical
    search as the fallback; confirm against the callers.
    """
    try:
        from biolord_normalizer import normalize_one as embed
        vector = embed(query)
    except Exception:
        # Best effort: a missing model or a failing call means "no embedding".
        return None
    return vector


async def _kg_lexical(query: str, limit: int = DEFAULT_K) -> list:
    """Case-insensitive substring search over Disease/Phenotype/Gene/Drug nodes.

    A node matches when the query is contained in its ``name`` or in its
    ``cid10DescriptionPt`` property (a missing description is treated as an
    empty string by the query).

    Args:
        query: Free-text needle. Surrounding whitespace is stripped; an empty
            or blank query returns ``[]`` without touching the graph, because
            ``CONTAINS ''`` would otherwise match every node and return
            arbitrary ones as "hits".
        limit: Maximum number of rows; a non-positive value returns ``[]``.

    Returns:
        Rows with ``kind`` (first label of the node), ``name`` and ``code``
        (first non-null of orphaCode/hpoId/symbol/rxcui, possibly null), or
        ``[]`` when the query helper fails.
    """
    needle = (query or "").strip()
    max_rows = int(limit)
    if not needle or max_rows <= 0:
        return []
    cypher = """
    MATCH (n)
    WHERE (n:Disease OR n:Phenotype OR n:Gene OR n:Drug)
      AND (toLower(n.name) CONTAINS toLower($q)
           OR toLower(coalesce(n.cid10DescriptionPt, '')) CONTAINS toLower($q))
    RETURN labels(n)[0] AS kind, n.name AS name,
           coalesce(n.orphaCode, n.hpoId, n.symbol, n.rxcui) AS code
    LIMIT $limit
    """
    return await _safe_query(cypher, {"q": needle, "limit": max_rows})


async def _patient_triples(patient_id: str, query_terms: list[str], limit: int = DEFAULT_K) -> list:
    """Build [subject, predicate, object] triples from the patient's subgraph.

    Three kinds of triples are emitted, each capped at ``limit`` entries:
      - ``patient:<id> HAS_PHENOTYPE <name> (<hpoId>)`` for phenotypes reached
        through the space's events,
      - ``disease:<orpha> RELATED_TO_PATIENT <name>`` for diseases that have
        one of those phenotypes,
      - ``hypothesis TARGETS <name> (ORPHA:<orpha>)`` for diseases targeted by
        the space's hypotheses.

    Args:
        patient_id: ``space_id`` of the PatientSpace node; falsy returns ``[]``.
        query_terms: Accepted for interface compatibility. It is currently
            not used to filter or rank the triples.
        limit: Per-kind cap on the number of triples.

    Returns:
        A list of 3-element lists, or ``[]`` when the space is not found or
        the query fails.
    """
    if not patient_id:
        return []
    # We pull the patient's connected entities and their relations to candidate diseases.
    cypher = """
    MATCH (s:PatientSpace {space_id: $sid})
    OPTIONAL MATCH (s)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    OPTIONAL MATCH (p)<-[:HAS_PHENOTYPE]-(d:Disease)
    OPTIONAL MATCH (s)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis)-[:TARGETS]->(d2:Disease)
    WITH s,
         collect(DISTINCT {hpo: p.hpoId, name: p.name}) AS phenos,
         collect(DISTINCT {orpha: d.orphaCode, name: d.name}) AS diseases,
         collect(DISTINCT {orpha: d2.orphaCode, name: d2.name}) AS hypos
    RETURN phenos, diseases, hypos
    """
    rows = await _safe_query(cypher, {"sid": patient_id})
    if not rows:
        return []
    row = rows[0]

    # The OPTIONAL MATCHes can yield all-null maps; drop those BEFORE capping
    # so a null placeholder never uses up one of the `limit` slots.
    phenos = [p for p in (row.get("phenos") or []) if p.get("hpo")][:limit]
    diseases = [d for d in (row.get("diseases") or []) if d.get("orpha")][:limit]
    hypos = [h for h in (row.get("hypos") or []) if h.get("orpha")][:limit]

    subject = f"patient:{patient_id}"
    triples = []
    for p in phenos:
        # The `name` key is always present in the map but may be null, so a
        # .get() default would never apply; use `or` for the fallback.
        triples.append([subject, "HAS_PHENOTYPE", f"{p.get('name') or '?'} ({p['hpo']})"])
    for d in diseases:
        triples.append([f"disease:{d['orpha']}", "RELATED_TO_PATIENT", d.get("name") or d["orpha"]])
    for h in hypos:
        triples.append(["hypothesis", "TARGETS", f"{h.get('name') or '?'} (ORPHA:{h['orpha']})"])
    return triples


async def _community_summary(query_terms: list[str], limit: int = 4) -> list:
    """Summarise diseases whose phenotypes match the query terms.

    Phenotypes whose name contains any term (case-insensitive) are grouped by
    the Disease they belong to. Diseases are ranked by the number of distinct
    matching phenotypes and the top ``limit`` are returned.

    Args:
        query_terms: Substrings to look for in phenotype names; empty returns ``[]``.
        limit: Maximum number of diseases ("communities") to return.

    Returns:
        A list of dicts with ``label`` (disease name), ``members`` (up to six
        matching phenotype names) and a one-line ``summary`` string.
    """
    if not query_terms:
        return []
    cypher = """
    MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
    WHERE any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
    WITH d, collect(DISTINCT p.name)[..6] AS members, count(DISTINCT p) AS n
    ORDER BY n DESC
    LIMIT $limit
    RETURN d.name AS label, members, n
    """
    query_params = {"terms": query_terms, "limit": int(limit)}
    records = await _safe_query(cypher, query_params)
    return [
        {
            "label": rec.get("label"),
            "members": rec.get("members") or [],
            "summary": f"{rec.get('n', 0)} fenótipos compatíveis com {rec.get('label')}",
        }
        for rec in records
    ]


async def _cohort_exemplars(case_id: str, query_terms: list[str], limit: int = 3) -> list:
    """Find other patient spaces with phenotypes matching the query terms.

    Spaces other than ``case_id`` are selected when one of their event
    phenotypes has a name containing any term (case-insensitive). A confirmed
    hypothesis, when present, supplies the diagnosis; otherwise it is null.

    Args:
        case_id: ``space_id`` to exclude; a falsy value is sent as ``"_"``.
        query_terms: Substrings to look for in phenotype names; empty returns ``[]``.
        limit: Maximum number of rows.

    Returns:
        Rows with ``space_id``, ``diagnosis`` and ``shared_hpos`` (up to six
        distinct hpoIds of the matching phenotypes).
    """
    if not query_terms:
        return []
    exemplar_query = """
    MATCH (other:PatientSpace)-[:HAS_EVENT]->(:SpaceEvent)-[:INVOLVES]->(p:Phenotype)
    WHERE other.space_id <> $self
      AND any(term IN $terms WHERE toLower(p.name) CONTAINS toLower(term))
    OPTIONAL MATCH (other)-[:HAS_HYPOTHESIS]->(h:SpaceHypothesis {status: 'confirmed'})-[:TARGETS]->(d:Disease)
    WITH other, d, collect(DISTINCT p.hpoId)[..6] AS shared
    RETURN other.space_id AS space_id,
           d.name AS diagnosis,
           shared AS shared_hpos
    LIMIT $limit
    """
    # "_" stands in for a missing case id so the <> comparison still has a value.
    excluded_space = case_id if case_id else "_"
    bindings = {"self": excluded_space, "terms": query_terms, "limit": int(limit)}
    return await _safe_query(exemplar_query, bindings)


async def _literature(query_terms: list[str], limit: int = 3) -> list:
    """Return the most recent Paper nodes whose title matches the query terms.

    A paper matches when its title contains any term (case-insensitive).
    Matches are ordered by year, newest first.

    Args:
        query_terms: Substrings to look for in titles; empty returns ``[]``.
        limit: Maximum number of papers.

    Returns:
        Rows with ``pmid``, ``title`` and ``year``.
    """
    if query_terms:
        paper_query = """
    MATCH (p:Paper)
    WHERE any(term IN $terms WHERE toLower(p.title) CONTAINS toLower(term))
    RETURN p.pmid AS pmid, p.title AS title, p.year AS year
    ORDER BY p.year DESC
    LIMIT $limit
    """
        bindings = {"terms": query_terms, "limit": int(limit)}
        return await _safe_query(paper_query, bindings)
    return []


def _split_terms(query: str, *, min_len: int = 4, max_terms: int = 8) -> list[str]:
    """Split a free-text query into a short list of search terms.

    Naive tokenizer: split on whitespace, strip surrounding punctuation, keep
    tokens of at least ``min_len`` characters that are not stopwords, and
    return at most ``max_terms`` of them in order of first appearance.

    Tokens that differ only by case are de-duplicated (first spelling wins),
    so repeats do not use up the ``max_terms`` budget.

    Args:
        query: Free-text query; ``None`` or empty yields ``[]``.
        min_len: Minimum token length after stripping punctuation.
        max_terms: Maximum number of terms returned; non-positive yields ``[]``.

    Returns:
        The retained tokens with their original casing.
    """
    if max_terms <= 0:
        return []
    stop = {"the", "with", "para", "tem", "uma", "está", "como", "and", "for"}
    punctuation = ".,;:!?()[]\"'`"
    terms: list[str] = []
    seen: set[str] = set()
    for word in (query or "").split():
        token = word.strip(punctuation)
        key = token.lower()
        if len(token) < min_len or key in stop or key in seen:
            continue
        seen.add(key)
        terms.append(token)
        if len(terms) >= max_terms:
            break
    return terms


async def retrieve(
    case_id: str,
    query: str,
    *,
    k: int = DEFAULT_K,
    mode: str = "local",
) -> dict:
    """Run GraphRAG retrieval grounded in the patient.

    Patient-subgraph triples are collected first, then lexical KG matches are
    appended as ``[kind, "MATCHES_QUERY", "<name> (<code>)"]`` triples, and the
    combined list is truncated to ``2 * k`` entries.

    NOTE(review): patient triples can exceed ``2 * k`` on their own (each of
    the three kinds is capped at ``k``), in which case the truncation drops
    every lexical match. Confirm this is intended.

    Args:
        case_id: PatientSpace id
        query: free-text query (e.g. "is this disease linked to NPC1 mutations?")
        k: per-source cap passed to the patient-triple and lexical lookups;
            at most ``2 * k`` triples are returned
        mode: "global" additionally fetches cohort exemplars and literature;
            any other value behaves like "local" (subgraph only)

    Returns:
        Dict with keys ``query``, ``case_id``, ``mode``, ``triples``,
        ``communities``, ``cohort_exemplars`` and ``literature``. The last two
        are empty lists unless ``mode == "global"``.
    """
    terms = _split_terms(query)

    triples = await _patient_triples(case_id, terms, limit=k)
    lex_hits = await _kg_lexical(query, limit=k)
    for hit in lex_hits[:k]:
        name = hit.get("name")
        if not name:
            continue
        # `kind` and `code` are always present as columns but may be null, so
        # a .get() default would never apply; `or` avoids rendering "None".
        kind = hit.get("kind") or "?"
        code = hit.get("code") or "?"
        triples.append([kind, "MATCHES_QUERY", f"{name} ({code})"])

    communities = await _community_summary(terms, limit=4)

    cohort_exemplars = []
    literature = []
    if mode == "global":
        cohort_exemplars = await _cohort_exemplars(case_id, terms, limit=3)
        literature = await _literature(terms, limit=3)

    return {
        "query": query,
        "case_id": case_id,
        "mode": mode,
        "triples": triples[: k * 2],
        "communities": communities,
        "cohort_exemplars": cohort_exemplars,
        "literature": literature,
    }


def format_for_llm(result: dict, *, max_triples: int = 12) -> str:
    """Render a retrieval result as a compact Markdown block for the LLM.

    Sections are emitted only when the corresponding key holds a non-empty
    value. Section headings are Portuguese strings and are emitted as-is.

    Args:
        result: Dict shaped like the return value of ``retrieve``; a falsy
            value yields an empty string.
        max_triples: Maximum number of triples rendered.

    Returns:
        The Markdown text, lines joined with newlines.

    Raises:
        KeyError: if a community lacks ``label``/``summary``, a cohort exemplar
            lacks ``space_id``, or a literature entry lacks ``pmid``.
    """
    if not result:
        return ""
    lines = [f"### Gemeo lookup — `{result.get('query', '')}`"]

    triples = result.get("triples")
    if triples:
        lines.append("**Triplas relevantes:**")
        for t in triples[:max_triples]:
            # Separate subject, predicate and object; concatenating them
            # directly produced unreadable text like "patient:1HAS_PHENOTYPE...".
            lines.append(f"- {t[0]} -[{t[1]}]-> {t[2]}")

    communities = result.get("communities")
    if communities:
        lines.append("\n**Clusters do KG:**")
        for c in communities:
            lines.append(f"- {c['label']}: {c['summary']}")

    exemplars = result.get("cohort_exemplars")
    if exemplars:
        lines.append("\n**Exemplares na coorte:**")
        for e in exemplars:
            dx = e.get("diagnosis") or "sem dx"
            # Tolerate a null list and non-string ids when joining.
            shared = ", ".join(str(h) for h in (e.get("shared_hpos") or [])[:4])
            lines.append(f"- {e['space_id']}: {dx} (HPO em comum: {shared})")

    papers = result.get("literature")
    if papers:
        lines.append("\n**Literatura:**")
        for paper in papers:
            # Keys may be present with null values, so use `or` fallbacks.
            year = paper.get("year") or "?"
            title = (paper.get("title") or "")[:100]
            lines.append(f"- PMID:{paper['pmid']} ({year}) — {title}")

    return "\n".join(lines)