File size: 7,474 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""KG sparsification — extract the patient-specific reasoning subgraph.

Inspired by:
  - "Knowledge Graph Sparsification for GNN-based Rare Disease Diagnosis"
    (arXiv 2510.08655, Oct 2025)
  - KARE (ICLR 2025) — KG community retrieval for reasoning
  - MedGraphRAG (ACL 2025) — triple-graph for grounded medical QA

Output is a small (~30-200 nodes) graph centered on the patient that can be:
  (a) rendered in the front-end with `react-force-graph-3d`
  (b) fed to the LLM as a structured triple list ("Patient ─[HAS_PHENOTYPE]→ HP:X ←[ANNOTATES]─ Disease ORPHA:Y")
  (c) used to extract narrated paths Patient→...→Disease
"""
from __future__ import annotations
import logging
from typing import Optional

from .types import Subgraph, SubgraphNode, SubgraphEdge

# Module-level logger; the name is the fixed string "gemeo.subgraph" rather
# than being derived from ``__name__``.
logger = logging.getLogger("gemeo.subgraph")


async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query on a best-effort basis.

    Delegates to ``space_graph._safe_query`` with a 15 second timeout.

    Args:
        cypher: The Cypher statement to execute.
        params: Query parameters; ``None`` is passed on as an empty dict.

    Returns:
        Whatever the delegate returns, or an empty list if anything raised
        (including a failure to import ``space_graph``). Failures are only
        logged at debug level, so callers cannot distinguish "no rows" from
        "query failed".
    """
    try:
        # Imported inside the try so that an import failure is handled
        # exactly like a query failure.
        from space_graph import _safe_query as q
        return await q(cypher, params or {}, timeout=15.0)
    except Exception as e:
        # Deliberate best-effort: degrade to "no rows" instead of raising.
        logger.debug("cypher failed: %s", e)
        return []


def _narrate_paths(nodes: dict, edges: list, pid: str, orphas: list) -> list:
    """Build one narrated Patient→Phenotype→Disease path per disease in *orphas*.

    A disease whose node is absent from *nodes* is skipped. For the others the
    first patient phenotype (in edge order) that is also linked from the
    disease node is used as the bridge; when none exists the path contains
    only the patient step.
    """
    # (source, target) pairs for O(1) "is this disease linked to that node?".
    linked = {(e.source, e.target) for e in edges}
    patient_hpo_targets = [
        e.target for e in edges
        if e.source == pid and e.rel == "HAS_PHENOTYPE"
    ]

    paths = []
    for orpha in orphas:
        did = f"disease:{orpha}"
        disease = nodes.get(did)
        if disease is None:
            continue
        steps = [{"node": pid, "rel": "is", "label": "Patient"}]
        for target in patient_hpo_targets:
            hpo_node = nodes.get(target)
            if hpo_node is None or (did, target) not in linked:
                continue
            steps.append({"node": target, "rel": "HAS_PHENOTYPE", "label": hpo_node.name})
            steps.append({"node": did, "rel": "DISEASE_HAS_PHENOTYPE_REVERSE", "label": disease.name})
            break
        paths.append({
            "target_orpha": orpha,
            "target_name": disease.name,
            "steps": steps,
        })
    return paths


async def extract(
    *,
    patient_id: str,
    hpo_ids: list[str],
    gene_symbols: Optional[list[str]] = None,
    target_orpha: Optional[str] = None,
    max_nodes: int = 80,
) -> Subgraph:
    """Extract reasoning subgraph for this patient.

    If `target_orpha` is given, extract paths Patient→...→that disease.
    Otherwise extract a 1-hop neighborhood centered on the patient's HPOs/genes
    and the top diseases that share phenotypes with the patient.

    Args:
        patient_id: Identifier used to build the central ``patient:<id>`` node.
        hpo_ids: Patient phenotype ids; de-duplicated, then capped at 30.
        gene_symbols: Patient gene symbols; de-duplicated, then capped at 10.
        target_orpha: Optional disease code that replaces the overlap-based
            candidate search.
        max_nodes: Upper bound on returned nodes; the lowest-weight nodes (and
            edges touching them) are dropped when exceeded.

    Returns:
        A ``Subgraph`` with nodes, edges, up to three narrated paths,
        ``method="cypher_sparsify"`` and ``target_disease=target_orpha``.
        Every edge refers to nodes that are present in the result.
    """
    # De-duplicate (order preserved) before capping so repeated ids neither
    # waste slots nor produce duplicate edges / repeated queries.
    hpos = list(dict.fromkeys(h for h in (hpo_ids or []) if h))[:30]
    genes = list(dict.fromkeys(g for g in (gene_symbols or []) if g))[:10]

    nodes: dict[str, SubgraphNode] = {}
    edges: list[SubgraphEdge] = []
    seen_edges: set = set()

    def add_edge(source: str, target: str, rel: str, weight: float) -> None:
        """Append an edge unless the same (source, target, rel) already exists."""
        key = (source, target, rel)
        if key in seen_edges:
            return
        seen_edges.add(key)
        edges.append(SubgraphEdge(source=source, target=target, rel=rel, weight=weight))

    # 1) the patient
    pid = f"patient:{patient_id}"
    nodes[pid] = SubgraphNode(
        id=pid, label="Patient", name="Patient",
        weight=1.0, extra={"is_center": True},
    )

    # 2) phenotypes
    for hpo in hpos:
        nid = f"hpo:{hpo}"
        # enrich with name
        rows = await _safe_query(
            "MATCH (p:Phenotype {hpoId: $hpo}) RETURN p.name AS name, p.definition AS def LIMIT 1",
            {"hpo": hpo},
        )
        # Fall back to the id when the lookup fails OR the stored name is null.
        name = (rows[0].get("name") if rows else None) or hpo
        nodes[nid] = SubgraphNode(id=nid, label="Phenotype", name=name, code=hpo, weight=0.9)
        add_edge(pid, nid, "HAS_PHENOTYPE", 1.0)

    # 3) genes
    for sym in genes:
        nid = f"gene:{sym}"
        nodes[nid] = SubgraphNode(id=nid, label="Gene", name=sym, code=sym, weight=0.9)
        add_edge(pid, nid, "HAS_GENE_VARIANT", 1.0)

    # 4) candidate diseases
    if target_orpha:
        candidate_orphas = [target_orpha]
    elif hpos:
        rows = await _safe_query(
            """
                MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
                WHERE p.hpoId IN $hpos
                WITH d, count(p) AS overlap
                ORDER BY overlap DESC
                LIMIT 8
                RETURN d.orphaCode AS orpha, d.name AS name, overlap
                """,
            {"hpos": hpos},
        )
        candidate_orphas = [r["orpha"] for r in rows if r.get("orpha")]
    else:
        candidate_orphas = []

    for orpha in candidate_orphas[:6]:
        rows = await _safe_query(
            """
            MATCH (d:Disease {orphaCode: $orpha})
            OPTIONAL MATCH (d)-[:HAS_PHENOTYPE]->(p:Phenotype)
              WHERE p.hpoId IN $hpos
            OPTIONAL MATCH (d)-[:ASSOCIATED_WITH]->(g:Gene)
              WHERE g.symbol IN $genes
            RETURN d.name AS name,
                   d.cid10 AS cid10,
                   collect(DISTINCT p.hpoId) AS shared_hpos,
                   collect(DISTINCT g.symbol) AS shared_genes
            """,
            {"orpha": orpha, "hpos": hpos, "genes": genes},
        )
        if not rows:
            continue
        r = rows[0]
        did = f"disease:{orpha}"
        nodes[did] = SubgraphNode(
            id=did, label="Disease", name=r.get("name") or orpha,
            code=orpha, weight=0.95,
            extra={"cid10": r.get("cid10")},
        )
        for hpo in (r.get("shared_hpos") or []):
            hid = f"hpo:{hpo}"
            if hid in nodes:
                add_edge(did, hid, "DISEASE_HAS_PHENOTYPE", 0.8)
        for sym in (r.get("shared_genes") or []):
            gid = f"gene:{sym}"
            if gid not in nodes:
                # gene mentioned by disease but not in patient — still informative
                nodes[gid] = SubgraphNode(id=gid, label="Gene", name=sym, code=sym, weight=0.6)
            add_edge(did, gid, "ASSOCIATED_WITH", 0.85)

    # 5) optional: drugs targeting candidate diseases (1-hop)
    # Only diseases that actually became nodes are queried; otherwise the
    # TREATED_BY edges would reference a source node missing from the graph.
    resolved_orphas = [o for o in candidate_orphas if f"disease:{o}" in nodes]
    if resolved_orphas:
        rows = await _safe_query(
            """
            MATCH (d:Disease)-[:TREATED_BY|TARGETED_BY]->(drug:Drug)
            WHERE d.orphaCode IN $orphas
            RETURN d.orphaCode AS orpha, drug.name AS name, drug.rxcui AS rxcui
            LIMIT 20
            """,
            {"orphas": resolved_orphas},
        )
        for r in rows:
            drug_name = r.get("name")
            did = f"disease:{r.get('orpha')}"
            if not drug_name or did not in nodes:
                continue
            drug_id = f"drug:{r.get('rxcui') or drug_name}"
            nodes[drug_id] = SubgraphNode(
                id=drug_id, label="Drug", name=drug_name,
                code=r.get("rxcui"), weight=0.7,
            )
            add_edge(did, drug_id, "TREATED_BY", 0.7)

    # cap node count by weight (stable sort: ties keep insertion order)
    if len(nodes) > max_nodes:
        kept = sorted(nodes.values(), key=lambda n: n.weight, reverse=True)[:max_nodes]
        kept_ids = {n.id for n in kept}
        nodes = {n.id: n for n in kept}
        edges = [e for e in edges if e.source in kept_ids and e.target in kept_ids]

    # 6) build narrated paths Patient→...→Disease
    paths = _narrate_paths(nodes, edges, pid, candidate_orphas[:3])

    return Subgraph(
        nodes=list(nodes.values()),
        edges=edges,
        paths=paths,
        method="cypher_sparsify",
        target_disease=target_orpha,
    )