| """Drug-drug interaction prediction. |
| |
| Critical for case-driven workflow: rare-disease patients are frequently |
| polymedicated (PCDT therapy + symptomatic + comorbidity meds), and |
| interactions are a top cause of preventable harm. |
| |
| Strategy: |
| 1. **KG walks** over Drug↔Drug↔Gene/Pathway/CYP edges in our enriched |
| biomedical graph (DrugBank + DDInter + CPIC, indexed by raras-app). |
| 2. **Severity classification** via interaction edges' attributes |
| (severity ∈ {minor, moderate, major, contraindicated}). |
| 3. **PK/PD mechanism narration** — extracted from interaction edge |
| metadata; LLM-rewritten for clinician-friendly text. |
4. **Phase-2 GNN** (gemeo/train/ddi_gnn.py) for unseen pairs — link
   prediction with mechanism-aware edge types. Planned only: the inference
   hook (`_try_ddi_gnn`) is still a stub and `predict` does not call it yet.
| |
Returns a `DdiSpec` with a severity-ranked list of pairwise predicted
interactions plus a single overall `regimen_risk` (the worst pairwise
severity, or "none" when no interaction is found).
| """ |
| from __future__ import annotations |
| import logging |
| import os |
| from typing import Optional |
|
|
| from .types import DdiSpec, DdiPair |
|
|
# Module logger, registered under the fixed name "gemeo.ddi" (not __name__).
logger = logging.getLogger("gemeo.ddi")


# Filesystem path of the optional Phase-2 DDI GNN checkpoint.  The
# GEMEO_DDI_CKPT environment variable, when set, takes precedence over the
# default location artifacts/ddi_gnn.pt next to this module.
DDI_GNN_CKPT = os.getenv(
    "GEMEO_DDI_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "ddi_gnn.pt"),
)
|
|
|
|
async def _safe_query(
    cypher: str,
    params: Optional[dict] = None,
    *,
    timeout: float = 10.0,
) -> list:
    """Run a Cypher query via ``space_graph._safe_query``; return ``[]`` on failure.

    This is a deliberate best-effort wrapper: any exception (including
    ``space_graph`` not being importable) yields an empty result instead of
    propagating.  Because callers cannot tell an empty result from "no
    interaction found", failures are logged at WARNING so a graph outage is
    visible rather than silently reported as a clean regimen.

    Args:
        cypher: Cypher query text.
        params: Query parameters; ``None`` is sent as ``{}``.
        timeout: Seconds forwarded to ``space_graph._safe_query``
            (keyword-only; defaults to the previous hard-coded 10.0).

    Returns:
        Whatever ``space_graph._safe_query`` returns, or ``[]`` if the import
        or the query raised.
    """
    try:
        # Imported inside the try so a missing space_graph degrades to [].
        from space_graph import _safe_query as q

        return await q(cypher, params or {}, timeout=timeout)
    except Exception as e:  # best-effort boundary; see docstring
        logger.warning("DDI cypher query failed: %s", e)
        return []
|
|
|
|
| _SEVERITY_RANK = { |
| "contraindicated": 4, |
| "major": 3, |
| "moderate": 2, |
| "minor": 1, |
| "unknown": 1, |
| None: 1, |
| } |
|
|
|
|
async def _kg_pairwise(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Look up a direct Drug-Drug ``INTERACTS_WITH`` edge in Neo4j.

    Each drug is matched on its ``rxcui`` against ``Drug.rxcui`` OR on its
    ``name`` against ``Drug.name`` (case-insensitive).  Matching each
    identifier against its own property means a drug that carries both still
    matches a node that only has one of them.  It also means a non-string
    rxcui is never passed to ``toLower``.  If more than one matching edge
    exists, the most severe one is returned.

    Args:
        drug_a: medication dict with ``name`` and/or ``rxcui`` keys.
        drug_b: medication dict with ``name`` and/or ``rxcui`` keys.

    Returns:
        The first result row (keys: severity, mechanism, evidence_level,
        management, references, a_name, b_name), or ``None`` when either drug
        has no identifier or no edge was found / the query failed.
    """
    a_rxcui, a_name = drug_a.get("rxcui") or None, drug_a.get("name") or None
    b_rxcui, b_name = drug_b.get("rxcui") or None, drug_b.get("name") or None
    if not (a_rxcui or a_name) or not (b_rxcui or b_name):
        return None
    # Null parameters make their comparison null, so only the identifiers
    # actually supplied can produce a match.
    cypher = """
    MATCH (a:Drug)-[r:INTERACTS_WITH]-(b:Drug)
    WHERE (a.rxcui = $a_rxcui OR toLower(a.name) = toLower($a_name))
      AND (b.rxcui = $b_rxcui OR toLower(b.name) = toLower($b_name))
    WITH a, b, r,
         CASE toLower(coalesce(r.severity, ''))
             WHEN 'contraindicated' THEN 4
             WHEN 'major' THEN 3
             WHEN 'moderate' THEN 2
             WHEN 'minor' THEN 1
             ELSE 0
         END AS severity_rank
    ORDER BY severity_rank DESC
    LIMIT 1
    RETURN r.severity AS severity,
           r.mechanism AS mechanism,
           r.evidence_level AS evidence_level,
           r.management AS management,
           r.references AS references,
           a.name AS a_name, b.name AS b_name
    """
    rows = await _safe_query(cypher, {
        "a_rxcui": a_rxcui, "a_name": a_name,
        "b_rxcui": b_rxcui, "b_name": b_name,
    })
    return rows[0] if rows else None
|
|
|
|
async def _kg_via_target(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Indirect interaction: shared CYP enzyme, transporter, or target.

    Both drugs are matched on ``rxcui`` against ``Drug.rxcui`` or on ``name``
    against ``Drug.name`` (case-insensitive), each identifier against its own
    property.  Among all shared nodes, one where at least one drug INHIBITS or
    INDUCES the node is preferred over a plain shared substrate/target, since
    it yields the more informative mechanism.

    Args:
        drug_a: medication dict with ``name`` and/or ``rxcui`` keys.
        drug_b: medication dict with ``name`` and/or ``rxcui`` keys.

    Returns:
        A synthetic interaction dict with the same keys as a ``_kg_pairwise``
        row (severity, mechanism, evidence_level, management, references,
        a_name, b_name), or ``None`` if no shared node was found.
    """
    a_rxcui, a_name = drug_a.get("rxcui") or None, drug_a.get("name") or None
    b_rxcui, b_name = drug_b.get("rxcui") or None, drug_b.get("name") or None
    if not (a_rxcui or a_name) or not (b_rxcui or b_name):
        return None
    cypher = """
    MATCH (a:Drug)-[ra:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]->(g)<-[rb:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]-(b:Drug)
    WHERE (a.rxcui = $a_rxcui OR toLower(a.name) = toLower($a_name))
      AND (b.rxcui = $b_rxcui OR toLower(b.name) = toLower($b_name))
      AND a <> b
    WITH a, b, g, ra, rb,
         CASE WHEN type(ra) IN ['INHIBITS', 'INDUCES']
                OR type(rb) IN ['INHIBITS', 'INDUCES'] THEN 1 ELSE 0 END AS modulating
    ORDER BY modulating DESC
    LIMIT 1
    RETURN coalesce(g.symbol, g.name) AS shared_target,
           labels(g)[0] AS target_kind,
           type(ra) AS a_rel, type(rb) AS b_rel,
           a.name AS a_name, b.name AS b_name
    """
    rows = await _safe_query(cypher, {
        "a_rxcui": a_rxcui, "a_name": a_name,
        "b_rxcui": b_rxcui, "b_name": b_name,
    })
    if not rows:
        return None
    r = rows[0]

    # `or` (not a .get default) so a present-but-null value still falls back.
    kind = r.get("target_kind") or "target"
    target = r.get("shared_target") or "unnamed node"
    mechanism = f"Shared {kind}: {target}"

    verbs = {
        "METABOLIZED_BY": "is metabolized by",
        "TARGETS": "targets",
        "INHIBITS": "inhibits",
        "INDUCES": "induces",
    }
    a_verb, b_verb = verbs.get(r.get("a_rel")), verbs.get(r.get("b_rel"))
    if a_verb and b_verb:
        mechanism += (
            f" ({r.get('a_name') or 'drug A'} {a_verb} {target}; "
            f"{r.get('b_name') or 'drug B'} {b_verb} {target})"
        )

    # NOTE(review): every indirect overlap is rated "moderate" regardless of
    # edge type; confirm this is the intended clinical default.
    return {
        "severity": "moderate",
        "mechanism": mechanism,
        "evidence_level": "indirect",
        "management": "Monitor for additive or competing effects.",
        "references": [],
        "a_name": r.get("a_name"),
        "b_name": r.get("b_name"),
    }
|
|
|
|
async def _try_ddi_gnn(drug_pairs):
    """Placeholder for Phase-2 GNN link prediction on unseen drug pairs.

    Not implemented yet: this coroutine always returns ``None``.  When a file
    exists at ``DDI_GNN_CKPT`` it first attempts ``import torch`` (the module
    is discarded and an ``ImportError`` is ignored).  ``drug_pairs`` is
    currently unused.
    """
    checkpoint_present = os.path.exists(DDI_GNN_CKPT)
    if checkpoint_present:
        try:
            import torch  # noqa: F401  # availability probe only; no model is loaded
        except ImportError:
            pass
    return None
|
|
|
|
async def predict(
    *,
    medications: list,
    add_drug: Optional[dict] = None,
) -> DdiSpec:
    """Predict drug-drug interactions across the regimen.

    Every unordered pair of drugs is checked for a direct ``INTERACTS_WITH``
    edge and, failing that, for a shared CYP/transporter/target
    (``_kg_pairwise`` then ``_kg_via_target``).

    Args:
        medications: list of {name, rxcui?} currently on the patient
        add_drug: optionally evaluate adding this drug (for what-if)

    Returns:
        A ``DdiSpec`` whose ``pairs`` are sorted most-severe first.
        ``regimen_risk`` is the worst pair's severity, or ``"none"`` when no
        interaction was found.  ``n_pairs_evaluated`` counts every pair
        considered, including pairs whose lookup raised.
    """
    drugs = list(medications or [])
    if add_drug:
        drugs.append(add_drug)

    # Fewer than two drugs: there is nothing to pair.
    if len(drugs) < 2:
        return DdiSpec(pairs=[], n_pairs_evaluated=0, regimen_risk="none", model="kg_walks")

    pairs_out = []
    n_evaluated = 0
    for i, a in enumerate(drugs):
        for b in drugs[i + 1:]:
            n_evaluated += 1
            try:
                # Direct interaction edge first; fall back to a shared target.
                hit = await _kg_pairwise(a, b)
                if hit is None:
                    hit = await _kg_via_target(a, b)
            except Exception as e:
                # An unresolved pair means "risk unknown", not "no risk";
                # make it visible instead of logging at DEBUG.
                logger.warning("DDI lookup failed for (%s, %s): %s", a, b, e)
                continue
            if hit is None:
                continue
            # Normalize case/whitespace so e.g. "Major" ranks as "major" in
            # _SEVERITY_RANK instead of falling to rank 0 and sorting last.
            severity = str(hit.get("severity") or "unknown").strip().lower() or "unknown"
            pairs_out.append(DdiPair(
                drug_a=a.get("name") or a.get("rxcui"),
                drug_b=b.get("name") or b.get("rxcui"),
                rxcui_a=a.get("rxcui"),
                rxcui_b=b.get("rxcui"),
                severity=severity,
                mechanism=hit.get("mechanism") or "",
                evidence_level=hit.get("evidence_level") or "kg",
                management=hit.get("management") or "",
                references=hit.get("references") or [],
            ))

    # Most severe first; list.sort is stable (also with reverse=True), so
    # equally ranked pairs keep their evaluation order.
    pairs_out.sort(key=lambda p: _SEVERITY_RANK.get(p.severity, 0), reverse=True)

    regimen_risk = pairs_out[0].severity if pairs_out else "none"

    return DdiSpec(
        pairs=pairs_out,
        n_pairs_evaluated=n_evaluated,
        regimen_risk=regimen_risk,
        # _try_ddi_gnn is a stub and is never consulted here, so every pair
        # comes from KG walks regardless of whether a checkpoint file exists.
        model="kg_walks",
    )
|
|