# Source: gemeo-twin-stack/src/gemeo/subgraph.py
# Commit 089d665 (verified), timmers:
#   GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
"""KG sparsification — extract the patient-specific reasoning subgraph.
Inspired by:
- "Knowledge Graph Sparsification for GNN-based Rare Disease Diagnosis"
(arXiv 2510.08655, Oct 2025)
- KARE (ICLR 2025) — KG community retrieval for reasoning
- MedGraphRAG (ACL 2025) — triple-graph for grounded medical QA
Output is a small (~30-200 nodes) graph centered on the patient that can be:
(a) rendered in the front-end with `react-force-graph-3d`
(b) fed to the LLM as a structured triple list ("Patient ─[HAS_PHENOTYPE]→ HP:X ←[ANNOTATES]─ Disease ORPHA:Y")
(c) used to extract narrated paths Patient→...→Disease
"""
from __future__ import annotations
import logging
from typing import Optional
from .types import Subgraph, SubgraphNode, SubgraphEdge
# Module-level logger; the name is the fixed string "gemeo.subgraph" rather than __name__.
logger = logging.getLogger("gemeo.subgraph")
async def _safe_query(cypher: str, params: dict = None) -> list:
try:
from space_graph import _safe_query as q
return await q(cypher, params or {}, timeout=15.0)
except Exception as e:
logger.debug(f"cypher failed: {e}")
return []
def _cap_by_weight(nodes: dict, edges: list, max_nodes: int) -> tuple[dict, list]:
    """Keep the `max_nodes` heaviest nodes and only the edges between them."""
    if len(nodes) <= max_nodes:
        return nodes, edges
    # sorted() is stable, so equal-weight nodes keep their insertion order.
    # A negative cap is clamped to 0 instead of slicing from the end.
    ranked = sorted(nodes.values(), key=lambda n: n.weight, reverse=True)
    kept = {n.id: n for n in ranked[:max(max_nodes, 0)]}
    kept_edges = [e for e in edges if e.source in kept and e.target in kept]
    return kept, kept_edges


def _narrate_paths(nodes: dict, edges: list, pid: str, candidate_orphas: list) -> list:
    """Build one Patient -> shared phenotype -> Disease path per top candidate."""
    # Pre-index the edges once so each lookup is O(1) instead of an edge rescan.
    linked = {(e.source, e.target) for e in edges}
    patient_hpos = [
        e.target
        for e in edges
        if e.source == pid and e.rel == "HAS_PHENOTYPE" and e.target in nodes
    ]

    paths = []
    for orpha in candidate_orphas[:3]:
        did = f"disease:{orpha}"
        if did not in nodes:
            continue
        steps = [{"node": pid, "rel": "is", "label": "Patient"}]
        # First patient phenotype that the disease also links to, if any.
        shared_hpo = next((h for h in patient_hpos if (did, h) in linked), None)
        if shared_hpo is not None:
            steps.append({"node": shared_hpo, "rel": "HAS_PHENOTYPE", "label": nodes[shared_hpo].name})
            steps.append({"node": did, "rel": "DISEASE_HAS_PHENOTYPE_REVERSE", "label": nodes[did].name})
        # A path is emitted even without a shared phenotype; it then contains
        # only the patient step.
        paths.append({
            "target_orpha": orpha,
            "target_name": nodes[did].name,
            "steps": steps,
        })
    return paths


async def extract(
    *,
    patient_id: str,
    hpo_ids: list[str],
    gene_symbols: Optional[list[str]] = None,
    target_orpha: Optional[str] = None,
    max_nodes: int = 80,
) -> Subgraph:
    """Extract the reasoning subgraph for one patient.

    If `target_orpha` is given, only that disease is considered. Otherwise the
    candidates are the diseases sharing the most phenotypes with the patient.

    Args:
        patient_id: Identifier used to build the central `patient:<id>` node.
        hpo_ids: Patient HPO ids; only the first 30 are used.
        gene_symbols: Patient gene symbols; only the first 10 are used.
        target_orpha: Optional ORPHA code to focus on.
        max_nodes: Upper bound on returned nodes; lowest-weight nodes are
            dropped first, together with their edges.

    Returns:
        A `Subgraph` with nodes, edges, up to three narrated paths, the method
        tag "cypher_sparsify" and `target_orpha` as the target disease.
    """
    # Truncate first (same inputs as before), then drop duplicates while
    # preserving order so repeated ids do not yield duplicate edges.
    hpos = list(dict.fromkeys((hpo_ids or [])[:30]))
    genes = list(dict.fromkeys((gene_symbols or [])[:10]))

    nodes: dict = {}
    edges: list = []

    # 1) the patient
    pid = f"patient:{patient_id}"
    nodes[pid] = SubgraphNode(
        id=pid, label="Patient", name="Patient",
        weight=1.0, extra={"is_center": True},
    )

    # 2) phenotypes — names are fetched in one batched query rather than one
    #    sequential round trip per HPO id.
    hpo_names: dict = {}
    if hpos:
        name_rows = await _safe_query(
            "MATCH (p:Phenotype) WHERE p.hpoId IN $hpos "
            "RETURN p.hpoId AS hpo, p.name AS name",
            {"hpos": hpos},
        )
        for row in name_rows:
            if row.get("name"):
                hpo_names.setdefault(row.get("hpo"), row["name"])
    for hpo in hpos:
        nid = f"hpo:{hpo}"
        # Fall back to the raw id when the KG has no (or a null) name.
        nodes[nid] = SubgraphNode(
            id=nid, label="Phenotype", name=hpo_names.get(hpo) or hpo, code=hpo, weight=0.9,
        )
        edges.append(SubgraphEdge(source=pid, target=nid, rel="HAS_PHENOTYPE", weight=1.0))

    # 3) genes
    for sym in genes:
        nid = f"gene:{sym}"
        nodes[nid] = SubgraphNode(id=nid, label="Gene", name=sym, code=sym, weight=0.9)
        edges.append(SubgraphEdge(source=pid, target=nid, rel="HAS_GENE_VARIANT", weight=1.0))

    # 4) candidate diseases
    if target_orpha:
        candidate_orphas = [target_orpha]
    elif hpos:
        rows = await _safe_query(
            """
            MATCH (p:Phenotype)<-[:HAS_PHENOTYPE]-(d:Disease)
            WHERE p.hpoId IN $hpos
            WITH d, count(p) AS overlap
            ORDER BY overlap DESC
            LIMIT 8
            RETURN d.orphaCode AS orpha, d.name AS name, overlap
            """,
            {"hpos": hpos},
        )
        candidate_orphas = list(dict.fromkeys(r["orpha"] for r in rows if r.get("orpha")))
    else:
        candidate_orphas = []

    for orpha in candidate_orphas[:6]:
        rows = await _safe_query(
            """
            MATCH (d:Disease {orphaCode: $orpha})
            OPTIONAL MATCH (d)-[:HAS_PHENOTYPE]->(p:Phenotype)
            WHERE p.hpoId IN $hpos
            OPTIONAL MATCH (d)-[:ASSOCIATED_WITH]->(g:Gene)
            WHERE g.symbol IN $genes
            RETURN d.name AS name,
                   d.cid10 AS cid10,
                   collect(DISTINCT p.hpoId) AS shared_hpos,
                   collect(DISTINCT g.symbol) AS shared_genes
            """,
            {"orpha": orpha, "hpos": hpos, "genes": genes},
        )
        if not rows:
            continue
        r = rows[0]
        did = f"disease:{orpha}"
        nodes[did] = SubgraphNode(
            id=did, label="Disease", name=r.get("name") or orpha,
            code=orpha, weight=0.95,
            extra={"cid10": r.get("cid10")},
        )
        for hpo in (r.get("shared_hpos") or []):
            hid = f"hpo:{hpo}"
            if hid in nodes:
                edges.append(SubgraphEdge(source=did, target=hid, rel="DISEASE_HAS_PHENOTYPE", weight=0.8))
        for sym in (r.get("shared_genes") or []):
            gid = f"gene:{sym}"
            if gid not in nodes:
                # Gene mentioned by the disease but not in the patient — still
                # informative. NOTE(review): the query above filters on $genes,
                # so this branch looks unreachable today; confirm whether the
                # filter or this branch reflects the intended behaviour.
                nodes[gid] = SubgraphNode(id=gid, label="Gene", name=sym, code=sym, weight=0.6)
            edges.append(SubgraphEdge(source=did, target=gid, rel="ASSOCIATED_WITH", weight=0.85))

    # 5) optional: drugs targeting candidate diseases (1-hop)
    # Only diseases that actually became nodes are queried: candidates beyond
    # the first six, or without a result row, have no node, and an edge from
    # them would dangle.
    disease_orphas = [o for o in candidate_orphas if f"disease:{o}" in nodes]
    if disease_orphas:
        rows = await _safe_query(
            """
            MATCH (d:Disease)-[:TREATED_BY|TARGETED_BY]->(drug:Drug)
            WHERE d.orphaCode IN $orphas
            RETURN d.orphaCode AS orpha, drug.name AS name, drug.rxcui AS rxcui
            LIMIT 20
            """,
            {"orphas": disease_orphas},
        )
        seen_drug_edges = set()
        for r in rows:
            drug_name = r.get("name")
            if not drug_name:
                continue
            did = f"disease:{r.get('orpha')}"
            if did not in nodes:
                continue
            drug_id = f"drug:{r.get('rxcui') or drug_name}"
            nodes[drug_id] = SubgraphNode(
                id=drug_id, label="Drug", name=drug_name,
                code=r.get("rxcui"), weight=0.7,
            )
            # The query matches two relationship types, so the same pair can
            # come back twice; emit a single edge per (disease, drug).
            # NOTE(review): the edge is labelled TREATED_BY for both types.
            if (did, drug_id) not in seen_drug_edges:
                seen_drug_edges.add((did, drug_id))
                edges.append(SubgraphEdge(source=did, target=drug_id, rel="TREATED_BY", weight=0.7))

    # cap node count by weight
    nodes, edges = _cap_by_weight(nodes, edges, max_nodes)

    # 6) build narrated paths Patient→...→Disease
    paths = _narrate_paths(nodes, edges, pid, candidate_orphas)

    return Subgraph(
        nodes=list(nodes.values()),
        edges=edges,
        paths=paths,
        method="cypher_sparsify",
        target_disease=target_orpha,
    )