"""KG sparsification — extract the patient-specific reasoning subgraph.

Inspired by:
- "Knowledge Graph Sparsification for GNN-based Rare Disease Diagnosis"
  (arXiv 2510.08655, Oct 2025)
- KARE (ICLR 2025) — KG community retrieval for reasoning
- MedGraphRAG (ACL 2025) — triple-graph for grounded medical QA

The output is a small graph centered on the patient (at most `max_nodes`
nodes, 80 by default) that can be:
  (a) rendered in the front end with `react-force-graph-3d`
  (b) fed to the LLM as a structured triple list
      ("Patient ─[HAS_PHENOTYPE]→ HP:X ←[DISEASE_HAS_PHENOTYPE]─ Disease ORPHA:Y")
  (c) used to extract narrated paths Patient→...→Disease
"""
| from __future__ import annotations |
| import logging |
| from typing import Optional |
|
|
| from .types import Subgraph, SubgraphNode, SubgraphEdge |
|
|
# Module-level logger. It uses a fixed channel name rather than __name__, so
# records from this module are always emitted under "gemeo.subgraph".
logger: logging.Logger = logging.getLogger("gemeo.subgraph")
|
|
|
|
async def _safe_query(
    cypher: str,
    params: Optional[dict] = None,
    timeout: float = 15.0,
) -> list:
    """Run a Cypher query through ``space_graph``; return ``[]`` on any failure.

    This is deliberately best-effort. Every failure is logged and turned into
    an empty result, so callers build a smaller subgraph instead of raising.

    Args:
        cypher: Cypher statement; ``$name`` placeholders are filled from
            ``params``.
        params: Query parameters. ``None`` is sent as an empty dict.
        timeout: Forwarded unchanged to ``space_graph._safe_query``
            (presumably seconds; confirm against that function).

    Returns:
        Whatever ``space_graph._safe_query`` returns, or ``[]`` if importing it
        or running the query raised.
    """
    try:
        # Imported per call rather than at module load. This is presumably to
        # avoid a hard dependency or import cycle with space_graph; confirm
        # before hoisting it to the top of the file.
        from space_graph import _safe_query as run_query
    except Exception as exc:
        # If the backend cannot be loaded, every query in this module returns
        # nothing. Surface that above DEBUG so it is not mistaken for an
        # ordinary per-query miss.
        logger.warning("space_graph unavailable, query skipped: %s", exc)
        return []
    try:
        return await run_query(cypher, params or {}, timeout=timeout)
    except Exception as exc:  # best-effort by design; see docstring
        logger.debug("cypher failed: %s", exc)
        return []
|
|
|
|
# Caps that keep the extracted subgraph small enough to render and to put into
# an LLM prompt.
_MAX_HPOS = 30  # patient phenotypes attached to the patient node
_MAX_GENES = 10  # patient genes attached to the patient node
_MAX_DISEASES = 6  # ranked candidate diseases expanded with shared evidence
_MAX_PATHS = 3  # ranked candidate diseases that get a narrated path


def _dedupe(values: Optional[list[str]], limit: int) -> list[str]:
    """Return the first ``limit`` distinct, non-empty values in input order."""
    return list(dict.fromkeys(v for v in (values or []) if v))[:limit]


def _link(
    edges: dict[tuple[str, str, str], SubgraphEdge],
    source: str,
    target: str,
    rel: str,
    weight: float,
) -> None:
    """Record the edge ``source -[rel]-> target`` unless it already exists."""
    key = (source, target, rel)
    if key not in edges:
        edges[key] = SubgraphEdge(source=source, target=target, rel=rel, weight=weight)


async def _attach_patient_inputs(
    nodes: dict[str, SubgraphNode],
    edges: dict[tuple[str, str, str], SubgraphEdge],
    pid: str,
    hpos: list[str],
    genes: list[str],
) -> None:
    """Add the patient's phenotype and gene nodes, each linked from the patient."""
    names: dict[str, str] = {}
    if hpos:
        # One batched name lookup instead of one round-trip per HPO term.
        rows = await _safe_query(
            "MATCH (p:Phenotype) WHERE p.hpoId IN $hpos "
            "RETURN p.hpoId AS hpo, p.name AS name",
            {"hpos": hpos},
        )
        for row in rows:
            # If several nodes share an hpoId, the first row wins.
            names.setdefault(row.get("hpo"), row.get("name"))

    for hpo in hpos:
        nid = f"hpo:{hpo}"
        # Fall back to the bare HPO id when the term is unknown or the query failed.
        nodes[nid] = SubgraphNode(
            id=nid, label="Phenotype", name=names.get(hpo) or hpo,
            code=hpo, weight=0.9,
        )
        _link(edges, pid, nid, "HAS_PHENOTYPE", 1.0)

    for sym in genes:
        nid = f"gene:{sym}"
        nodes[nid] = SubgraphNode(id=nid, label="Gene", name=sym, code=sym, weight=0.9)
        _link(edges, pid, nid, "HAS_GENE_VARIANT", 1.0)


async def _rank_candidates(hpos: list[str], target_orpha: Optional[str]) -> list[str]:
    """Return the ORPHA codes to consider, best first.

    If ``target_orpha`` is set, it is the only candidate. Otherwise the result
    is up to 8 diseases ordered by how many of ``hpos`` they share.
    """
    if target_orpha:
        return [target_orpha]
    if not hpos:
        return []
    rows = await _safe_query(
        """
        MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
        WHERE p.hpoId IN $hpos
        WITH d, count(p) AS overlap
        ORDER BY overlap DESC
        LIMIT 8
        RETURN d.orphaCode AS orpha, d.name AS name, overlap
        """,
        {"hpos": hpos},
    )
    # dict.fromkeys drops repeated codes while keeping the ranking order.
    return list(dict.fromkeys(r["orpha"] for r in rows if r.get("orpha")))


async def _expand_disease(
    nodes: dict[str, SubgraphNode],
    edges: dict[tuple[str, str, str], SubgraphEdge],
    orpha: str,
    hpos: list[str],
    genes: list[str],
) -> bool:
    """Add one disease node and link it to the phenotypes/genes it shares with the patient.

    Returns:
        True if the disease node was added. False (nothing added) when the
        query returns no rows, e.g. an unknown ORPHA code or a failed query.
    """
    rows = await _safe_query(
        """
        MATCH (d:Disease {orphaCode: $orpha})
        OPTIONAL MATCH (d)-[:HAS_PHENOTYPE]->(p:Phenotype)
        WHERE p.hpoId IN $hpos
        OPTIONAL MATCH (d)-[:ASSOCIATED_WITH]->(g:Gene)
        WHERE g.symbol IN $genes
        RETURN d.name AS name,
               d.cid10 AS cid10,
               collect(DISTINCT p.hpoId) AS shared_hpos,
               collect(DISTINCT g.symbol) AS shared_genes
        """,
        {"orpha": orpha, "hpos": hpos, "genes": genes},
    )
    if not rows:
        return False
    row = rows[0]

    did = f"disease:{orpha}"
    nodes[did] = SubgraphNode(
        id=did, label="Disease", name=row.get("name") or orpha,
        code=orpha, weight=0.95,
        extra={"cid10": row.get("cid10")},
    )
    for hpo in row.get("shared_hpos") or []:
        hid = f"hpo:{hpo}"
        if hid in nodes:
            _link(edges, did, hid, "DISEASE_HAS_PHENOTYPE", 0.8)
    for sym in row.get("shared_genes") or []:
        gid = f"gene:{sym}"
        # The query only returns symbols from `genes`, which are all nodes by
        # now. The low-weight node is a defensive fallback so the edge never
        # dangles.
        if gid not in nodes:
            nodes[gid] = SubgraphNode(id=gid, label="Gene", name=sym, code=sym, weight=0.6)
        _link(edges, did, gid, "ASSOCIATED_WITH", 0.85)
    return True


async def _attach_drugs(
    nodes: dict[str, SubgraphNode],
    edges: dict[tuple[str, str, str], SubgraphEdge],
    orphas: list[str],
) -> None:
    """Link up to 20 distinct (disease, drug) rows to diseases already in ``nodes``."""
    if not orphas:
        return
    rows = await _safe_query(
        """
        MATCH (d:Disease)-[:TREATED_BY|TARGETED_BY]->(drug:Drug)
        WHERE d.orphaCode IN $orphas
        RETURN DISTINCT d.orphaCode AS orpha, drug.name AS name, drug.rxcui AS rxcui
        LIMIT 20
        """,
        {"orphas": orphas},
    )
    for row in rows:
        drug_name = row.get("name")
        did = f"disease:{row.get('orpha')}"
        # Skip rows whose disease is not a node, so no edge dangles.
        if not drug_name or did not in nodes:
            continue
        rxcui = row.get("rxcui")
        drug_id = f"drug:{rxcui or drug_name}"
        nodes[drug_id] = SubgraphNode(
            id=drug_id, label="Drug", name=drug_name,
            code=rxcui, weight=0.7,
        )
        _link(edges, did, drug_id, "TREATED_BY", 0.7)


def _prune(
    nodes: dict[str, SubgraphNode],
    edges: dict[tuple[str, str, str], SubgraphEdge],
    max_nodes: int,
) -> tuple[dict[str, SubgraphNode], dict[tuple[str, str, str], SubgraphEdge]]:
    """Keep the ``max_nodes`` heaviest nodes and only the edges between kept nodes."""
    if len(nodes) <= max_nodes:
        return nodes, edges
    # sorted() is stable, so equal-weight nodes keep their insertion order.
    # The patient (weight 1.0) outranks every other node weight set here.
    kept = sorted(nodes.values(), key=lambda n: n.weight, reverse=True)[:max_nodes]
    kept_nodes = {n.id: n for n in kept}
    kept_edges = {
        key: e for key, e in edges.items()
        if e.source in kept_nodes and e.target in kept_nodes
    }
    return kept_nodes, kept_edges


def _build_paths(
    nodes: dict[str, SubgraphNode],
    edges: dict[tuple[str, str, str], SubgraphEdge],
    pid: str,
    orphas: list[str],
) -> list[dict]:
    """Build one Patient -> phenotype -> Disease path per disease present in ``nodes``.

    The phenotype is the first patient phenotype, in insertion order, that the
    disease also links to. If there is none, the path contains only the
    patient step.
    """
    # Set lookups replace the per-edge rescans of the edge list.
    linked = {(e.source, e.target) for e in edges.values()}
    patient_hpos = [
        e.target for e in edges.values()
        if e.source == pid and e.rel == "HAS_PHENOTYPE" and e.target in nodes
    ]

    paths = []
    for orpha in orphas:
        did = f"disease:{orpha}"
        disease = nodes.get(did)
        if disease is None:
            continue
        steps = [{"node": pid, "rel": "is", "label": "Patient"}]
        shared = next((h for h in patient_hpos if (did, h) in linked), None)
        if shared is not None:
            steps.append({"node": shared, "rel": "HAS_PHENOTYPE", "label": nodes[shared].name})
            steps.append({"node": did, "rel": "DISEASE_HAS_PHENOTYPE_REVERSE", "label": disease.name})
        paths.append({
            "target_orpha": orpha,
            "target_name": disease.name,
            "steps": steps,
        })
    return paths


async def extract(
    *,
    patient_id: str,
    hpo_ids: list[str],
    gene_symbols: Optional[list[str]] = None,
    target_orpha: Optional[str] = None,
    max_nodes: int = 80,
) -> Subgraph:
    """Extract the reasoning subgraph for one patient.

    If ``target_orpha`` is given, only that disease is expanded, and paths run
    Patient -> ... -> that disease. Otherwise the graph holds the patient's
    HPOs/genes plus the diseases that share the most phenotypes with them.

    Args:
        patient_id: Used to build the center node id ``patient:<patient_id>``.
        hpo_ids: Patient HPO ids. Duplicates and empty values are dropped, and
            at most 30 distinct ids are used.
        gene_symbols: Patient gene symbols, deduplicated and capped at 10.
        target_orpha: Optional ORPHA code to explain instead of ranking
            candidate diseases.
        max_nodes: Upper bound on returned nodes. The lowest-weight nodes are
            dropped first, together with any edges touching them.

    Returns:
        A ``Subgraph`` whose every edge connects two returned nodes.
    """
    hpos = _dedupe(hpo_ids, _MAX_HPOS)
    genes = _dedupe(gene_symbols, _MAX_GENES)

    nodes: dict[str, SubgraphNode] = {}
    # Keyed by (source, target, rel) so repeated evidence yields a single edge.
    edges: dict[tuple[str, str, str], SubgraphEdge] = {}

    pid = f"patient:{patient_id}"
    nodes[pid] = SubgraphNode(
        id=pid, label="Patient", name="Patient",
        weight=1.0, extra={"is_center": True},
    )
    await _attach_patient_inputs(nodes, edges, pid, hpos, genes)

    candidates = await _rank_candidates(hpos, target_orpha)
    expanded = []
    for orpha in candidates[:_MAX_DISEASES]:
        if await _expand_disease(nodes, edges, orpha, hpos, genes):
            expanded.append(orpha)
    # Only diseases that actually became nodes get drug edges.
    await _attach_drugs(nodes, edges, expanded)

    nodes, edges = _prune(nodes, edges, max_nodes)
    paths = _build_paths(nodes, edges, pid, candidates[:_MAX_PATHS])

    return Subgraph(
        nodes=list(nodes.values()),
        edges=list(edges.values()),
        paths=paths,
        method="cypher_sparsify",
        target_disease=target_orpha,
    )
|
|