timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Drug-drug interaction prediction.
Critical for case-driven workflow: rare-disease patients are frequently
polymedicated (PCDT therapy + symptomatic + comorbidity meds), and
interactions are a top cause of preventable harm.
Strategy:
1. **KG walks** over Drug↔Drug↔Gene/Pathway/CYP edges in our enriched
biomedical graph (DrugBank + DDInter + CPIC, indexed by raras-app).
2. **Severity classification** via interaction edges' attributes
(severity ∈ {minor, moderate, major, contraindicated}).
3. **PK/PD mechanism narration** — extracted from interaction edge
metadata; LLM-rewritten for clinician-friendly text.
4. **Phase-2 GNN** (gemeo/train/ddi_gnn.py) for unseen pairs — link
prediction with mechanism-aware edge types.
Returns a `DdiSpec` with a ranked list of pairwise predicted interactions
plus a single overall `regimen_risk` for the regimen.
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from .types import DdiSpec, DdiPair
logger = logging.getLogger("gemeo.ddi")
# Path of the DDI GNN checkpoint. The GEMEO_DDI_CKPT environment variable
# overrides the default, which is "artifacts/ddi_gnn.pt" next to this module.
DDI_GNN_CKPT = os.getenv(
    "GEMEO_DDI_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "ddi_gnn.pt"),
)
async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query through ``space_graph`` without ever raising.

    Args:
        cypher: The query text.
        params: Query parameters; ``None`` (or an empty dict) is sent as ``{}``.

    Returns:
        Whatever ``space_graph._safe_query`` returns (presumably a list of
        result rows), or ``[]`` when the import or the query raised.
    """
    try:
        # Imported lazily, inside the ``try``, so that a missing or broken
        # ``space_graph`` module is handled exactly like a failed query.
        from space_graph import _safe_query as run_query

        return await run_query(cypher, params or {}, timeout=10.0)
    except Exception as exc:
        # Deliberately best-effort: callers treat an empty result as
        # "nothing found". %-style args defer formatting until the record
        # is actually emitted.
        logger.debug("cypher failed: %s", exc)
        return []
# Numeric ordering of interaction severities: a larger value is more severe.
# "unknown" and a missing (None) severity share the rank of "minor".
# predict() sorts pairs on this table and falls back to rank 0 for any
# severity string that is not listed here.
_SEVERITY_RANK = {
    "contraindicated": 4,
    "major": 3,
    "moderate": 2,
    "minor": 1,
    "unknown": 1,
    None: 1,
}
async def _kg_pairwise(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Fetch one direct Drug-Drug ``INTERACTS_WITH`` edge from the graph.

    Each drug is identified by its ``rxcui`` when that is truthy, otherwise
    by its ``name``; the query matches the identifier against either
    property (names case-insensitively) and ignores edge direction.

    Returns:
        The first result row, or ``None`` when either drug lacks a usable
        identifier or the query produced no rows.
    """
    key_a = drug_a.get("rxcui") or drug_a.get("name")
    key_b = drug_b.get("rxcui") or drug_b.get("name")
    if not (key_a and key_b):
        return None

    # Joined with leading/trailing empty entries so the query text starts
    # and ends with a newline.
    query = "\n".join((
        "",
        "MATCH (a:Drug)-[r:INTERACTS_WITH]-(b:Drug)",
        "WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a))",
        "AND (b.rxcui = $b OR toLower(b.name) = toLower($b))",
        "RETURN r.severity AS severity,",
        "r.mechanism AS mechanism,",
        "r.evidence_level AS evidence_level,",
        "r.management AS management,",
        "r.references AS references,",
        "a.name AS a_name, b.name AS b_name",
        "LIMIT 1",
        "",
    ))
    matches = await _safe_query(query, {"a": key_a, "b": key_b})
    if matches:
        return matches[0]
    return None
async def _kg_via_target(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Look for an indirect interaction: both drugs point at a shared node.

    The query matches two distinct ``Drug`` nodes that each have an outgoing
    ``METABOLIZED_BY``, ``TARGETS``, ``INHIBITS`` or ``INDUCES`` relationship
    to the same node. Drugs are identified by ``rxcui`` when truthy,
    otherwise by ``name``.

    Returns:
        ``None`` when either drug lacks an identifier or nothing is shared.
        Otherwise a dict shaped like a direct-interaction row, with a fixed
        ``"moderate"`` severity, ``"indirect"`` evidence level, a generic
        management note, no references, and a mechanism text that names the
        shared node.
    """
    key_a = drug_a.get("rxcui") or drug_a.get("name")
    key_b = drug_b.get("rxcui") or drug_b.get("name")
    if not key_a or not key_b:
        return None

    # Joined with leading/trailing empty entries so the query text starts
    # and ends with a newline.
    cypher = "\n".join((
        "",
        "MATCH (a:Drug)-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]->(g)<-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]-(b:Drug)",
        "WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a))",
        "AND (b.rxcui = $b OR toLower(b.name) = toLower($b))",
        "AND a <> b",
        "RETURN g.symbol AS shared_target,",
        "labels(g)[0] AS target_kind,",
        "a.name AS a_name, b.name AS b_name",
        "LIMIT 1",
        "",
    ))
    rows = await _safe_query(cypher, {"a": key_a, "b": key_b})
    if not rows:
        return None

    row = rows[0]
    # A null column presumably arrives as a key whose value is None, in which
    # case a ``dict.get`` default would not apply; use ``or`` so the text
    # never reads "Shared None: None".
    kind = row.get("target_kind") or "target"
    shared = row.get("shared_target")
    mechanism = f"Shared {kind}: {shared}" if shared else f"Shared {kind}"
    return {
        "severity": "moderate",
        "mechanism": mechanism,
        "evidence_level": "indirect",
        "management": "Monitor for additive or competing effects.",
        "references": [],
        "a_name": row.get("a_name"),
        "b_name": row.get("b_name"),
    }
async def _try_ddi_gnn(drug_pairs):
    """Placeholder for the phase-2 GNN path; currently always returns None.

    Probes for the checkpoint file and, if it exists, for an importable
    ``torch``; neither outcome changes the result and ``drug_pairs`` is not
    used yet.
    """
    checkpoint_present = os.path.exists(DDI_GNN_CKPT)
    if checkpoint_present:
        try:
            import torch  # noqa: F401
        except ImportError:
            pass
    return None  # phase-2
async def predict(
    *,
    medications: list,
    add_drug: Optional[dict] = None,
) -> DdiSpec:
    """Predict drug-drug interactions across the regimen.

    Every unordered pair of drugs is looked up first as a direct
    interaction (``_kg_pairwise``) and, if none is found, as an indirect
    one through a shared node (``_kg_via_target``).

    Args:
        medications: list of {name, rxcui?} currently on the patient
        add_drug: optionally evaluate adding this drug (for what-if)

    Returns:
        A ``DdiSpec`` whose ``pairs`` are sorted from most to least severe.
        ``n_pairs_evaluated`` counts every pair attempted, including pairs
        whose lookup failed or found nothing. ``regimen_risk`` is the
        severity of the top-ranked pair, or ``"none"`` when no interaction
        was found or fewer than two drugs were given.
    """
    from itertools import combinations

    drugs = list(medications or [])
    if add_drug:
        drugs.append(add_drug)
    if len(drugs) < 2:
        return DdiSpec(pairs=[], n_pairs_evaluated=0, regimen_risk="none", model="kg_walks")

    pairs_out = []
    n_evaluated = 0
    for a, b in combinations(drugs, 2):
        n_evaluated += 1
        try:
            hit = await _kg_pairwise(a, b)
            if hit is None:
                hit = await _kg_via_target(a, b)
        except Exception as exc:
            # Best-effort: one bad entry (e.g. a non-dict drug) must not
            # abort the evaluation of the remaining pairs.
            logger.debug("DDI lookup failed for (%s,%s): %s", a, b, exc)
            continue
        if hit is None:
            continue

        # Normalise case/whitespace so that e.g. "Major" is ranked as
        # "major" instead of falling through to the rank-0 default and
        # sorting below "minor".
        raw_severity = hit.get("severity")
        severity = str(raw_severity).strip().lower() if raw_severity else ""

        pairs_out.append(DdiPair(
            drug_a=a.get("name") or a.get("rxcui"),
            drug_b=b.get("name") or b.get("rxcui"),
            rxcui_a=a.get("rxcui"),
            rxcui_b=b.get("rxcui"),
            severity=severity or "unknown",
            mechanism=hit.get("mechanism") or "",
            evidence_level=hit.get("evidence_level") or "kg",
            management=hit.get("management") or "",
            references=hit.get("references") or [],
        ))

    pairs_out.sort(key=lambda p: _SEVERITY_RANK.get(p.severity, 0), reverse=True)
    # After sorting, the first pair carries the highest-ranked severity.
    regimen_risk = pairs_out[0].severity if pairs_out else "none"

    return DdiSpec(
        pairs=pairs_out,
        n_pairs_evaluated=n_evaluated,
        regimen_risk=regimen_risk,
        # Only knowledge-graph lookups produce these pairs; the GNN is never
        # invoked here, so the presence of a checkpoint file must not change
        # the reported model.
        model="kg_walks",
    )