"""KG sparsification — extract the patient-specific reasoning subgraph. Inspired by: - "Knowledge Graph Sparsification for GNN-based Rare Disease Diagnosis" (arXiv 2510.08655, Oct 2025) - KARE (ICLR 2025) — KG community retrieval for reasoning - MedGraphRAG (ACL 2025) — triple-graph for grounded medical QA Output is a small (~30-200 nodes) graph centered on the patient that can be: (a) rendered in the front-end with `react-force-graph-3d` (b) fed to the LLM as a structured triple list ("Patient ─[HAS_PHENOTYPE]→ HP:X ←[ANNOTATES]─ Disease ORPHA:Y") (c) used to extract narrated paths Patient→...→Disease """ from __future__ import annotations import logging from typing import Optional from .types import Subgraph, SubgraphNode, SubgraphEdge logger = logging.getLogger("gemeo.subgraph") async def _safe_query(cypher: str, params: dict = None) -> list: try: from space_graph import _safe_query as q return await q(cypher, params or {}, timeout=15.0) except Exception as e: logger.debug(f"cypher failed: {e}") return [] async def extract( *, patient_id: str, hpo_ids: list[str], gene_symbols: list[str] = None, target_orpha: str = None, max_nodes: int = 80, ) -> Subgraph: """Extract reasoning subgraph for this patient. If `target_orpha` is given, extract paths Patient→...→that disease. Otherwise extract a 1-hop neighborhood centered on the patient's HPOs/genes and the top diseases that share phenotypes with the patient. """ gene_symbols = gene_symbols or [] nodes: dict = {} edges: list = [] # 1) the patient pid = f"patient:{patient_id}" nodes[pid] = SubgraphNode( id=pid, label="Patient", name="Patient", weight=1.0, extra={"is_center": True}, ) # 2) phenotypes for hpo in hpo_ids[:30]: nid = f"hpo:{hpo}" # enrich with name rows = await _safe_query( "MATCH (p:Phenotype {hpoId: $hpo}) RETURN p.name AS name, p.definition AS def LIMIT 1", {"hpo": hpo}, ) name = (rows[0]["name"] if rows else hpo) nodes[nid] = SubgraphNode(id=nid, label="Phenotype", name=name, code=hpo, weight=0.9) edges.append(SubgraphEdge(source=pid, target=nid, rel="HAS_PHENOTYPE", weight=1.0)) # 3) genes for sym in gene_symbols[:10]: nid = f"gene:{sym}" nodes[nid] = SubgraphNode(id=nid, label="Gene", name=sym, code=sym, weight=0.9) edges.append(SubgraphEdge(source=pid, target=nid, rel="HAS_GENE_VARIANT", weight=1.0)) # 4) candidate diseases if target_orpha: candidate_orphas = [target_orpha] else: if hpo_ids: rows = await _safe_query( """ MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease) WHERE p.hpoId IN $hpos WITH d, count(p) AS overlap ORDER BY overlap DESC LIMIT 8 RETURN d.orphaCode AS orpha, d.name AS name, overlap """, {"hpos": hpo_ids[:30]}, ) candidate_orphas = [r["orpha"] for r in rows if r.get("orpha")] else: candidate_orphas = [] for orpha in candidate_orphas[:6]: rows = await _safe_query( """ MATCH (d:Disease {orphaCode: $orpha}) OPTIONAL MATCH (d)-[:HAS_PHENOTYPE]->(p:Phenotype) WHERE p.hpoId IN $hpos OPTIONAL MATCH (d)-[:ASSOCIATED_WITH]->(g:Gene) WHERE g.symbol IN $genes RETURN d.name AS name, d.cid10 AS cid10, collect(DISTINCT p.hpoId) AS shared_hpos, collect(DISTINCT g.symbol) AS shared_genes """, {"orpha": orpha, "hpos": hpo_ids[:30], "genes": gene_symbols[:10]}, ) if not rows: continue r = rows[0] did = f"disease:{orpha}" nodes[did] = SubgraphNode( id=did, label="Disease", name=r.get("name") or orpha, code=orpha, weight=0.95, extra={"cid10": r.get("cid10")}, ) for hpo in (r.get("shared_hpos") or []): hid = f"hpo:{hpo}" if hid in nodes: edges.append(SubgraphEdge(source=did, target=hid, rel="DISEASE_HAS_PHENOTYPE", weight=0.8)) for sym in (r.get("shared_genes") or []): gid = f"gene:{sym}" if gid in nodes: edges.append(SubgraphEdge(source=did, target=gid, rel="ASSOCIATED_WITH", weight=0.85)) else: # gene mentioned by disease but not in patient — still informative gid = f"gene:{sym}" nodes[gid] = SubgraphNode(id=gid, label="Gene", name=sym, code=sym, weight=0.6) edges.append(SubgraphEdge(source=did, target=gid, rel="ASSOCIATED_WITH", weight=0.85)) # 5) optional: drugs targeting candidate diseases (1-hop) if candidate_orphas: rows = await _safe_query( """ MATCH (d:Disease)-[:TREATED_BY|TARGETED_BY]->(drug:Drug) WHERE d.orphaCode IN $orphas RETURN d.orphaCode AS orpha, drug.name AS name, drug.rxcui AS rxcui LIMIT 20 """, {"orphas": candidate_orphas}, ) for r in rows: drug_name = r.get("name") if not drug_name: continue did = f"disease:{r['orpha']}" drug_id = f"drug:{r.get('rxcui') or drug_name}" nodes[drug_id] = SubgraphNode( id=drug_id, label="Drug", name=drug_name, code=r.get("rxcui"), weight=0.7, ) edges.append(SubgraphEdge(source=did, target=drug_id, rel="TREATED_BY", weight=0.7)) # cap node count by weight if len(nodes) > max_nodes: kept = sorted(nodes.values(), key=lambda n: n.weight, reverse=True)[:max_nodes] kept_ids = {n.id for n in kept} nodes = {n.id: n for n in kept} edges = [e for e in edges if e.source in kept_ids and e.target in kept_ids] # 6) build narrated paths Patient→...→Disease paths = [] for orpha in candidate_orphas[:3]: did = f"disease:{orpha}" if did not in nodes: continue steps = [] steps.append({"node": pid, "rel": "is", "label": "Patient"}) # find a shared phenotype for e in edges: if e.source == pid and e.rel == "HAS_PHENOTYPE": hpo_node = nodes.get(e.target) if not hpo_node: continue # is this phenotype linked to the disease? shared = any( ee.source == did and ee.target == e.target for ee in edges ) if shared: steps.append({"node": e.target, "rel": "HAS_PHENOTYPE", "label": hpo_node.name}) steps.append({"node": did, "rel": "DISEASE_HAS_PHENOTYPE_REVERSE", "label": nodes[did].name}) break paths.append({ "target_orpha": orpha, "target_name": nodes[did].name, "steps": steps, }) return Subgraph( nodes=list(nodes.values()), edges=edges, paths=paths, method="cypher_sparsify", target_disease=target_orpha, )