"""Drug-drug interaction prediction. Critical for case-driven workflow: rare-disease patients are frequently polymedicated (PCDT therapy + symptomatic + comorbidity meds), and interactions are a top cause of preventable harm. Strategy: 1. **KG walks** over Drug↔Drug↔Gene/Pathway/CYP edges in our enriched biomedical graph (DrugBank + DDInter + CPIC, indexed by raras-app). 2. **Severity classification** via interaction edges' attributes (severity ∈ {minor, moderate, major, contraindicated}). 3. **PK/PD mechanism narration** — extracted from interaction edge metadata; LLM-rewritten for clinician-friendly text. 4. **Phase-2 GNN** (gemeo/train/ddi_gnn.py) for unseen pairs — link prediction with mechanism-aware edge types. Returns a `DdiSpec` with a ranked list of pairwise predicted interactions plus a single overall `risk_level` for the regimen. """ from __future__ import annotations import logging import os from typing import Optional from .types import DdiSpec, DdiPair logger = logging.getLogger("gemeo.ddi") DDI_GNN_CKPT = os.environ.get( "GEMEO_DDI_CKPT", os.path.join(os.path.dirname(__file__), "artifacts", "ddi_gnn.pt"), ) async def _safe_query(cypher: str, params: dict = None) -> list: try: from space_graph import _safe_query as q return await q(cypher, params or {}, timeout=10.0) except Exception as e: logger.debug(f"cypher failed: {e}") return [] _SEVERITY_RANK = { "contraindicated": 4, "major": 3, "moderate": 2, "minor": 1, "unknown": 1, None: 1, } async def _kg_pairwise(drug_a: dict, drug_b: dict) -> Optional[dict]: """Look up a single Drug-Drug interaction edge in Neo4j.""" a_key = drug_a.get("rxcui") or drug_a.get("name") b_key = drug_b.get("rxcui") or drug_b.get("name") if not a_key or not b_key: return None cypher = """ MATCH (a:Drug)-[r:INTERACTS_WITH]-(b:Drug) WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a)) AND (b.rxcui = $b OR toLower(b.name) = toLower($b)) RETURN r.severity AS severity, r.mechanism AS mechanism, r.evidence_level AS evidence_level, r.management AS management, r.references AS references, a.name AS a_name, b.name AS b_name LIMIT 1 """ rows = await _safe_query(cypher, {"a": a_key, "b": b_key}) return rows[0] if rows else None async def _kg_via_target(drug_a: dict, drug_b: dict) -> Optional[dict]: """Indirect interaction: shared CYP enzyme, transporter, or target.""" a_key = drug_a.get("rxcui") or drug_a.get("name") b_key = drug_b.get("rxcui") or drug_b.get("name") if not a_key or not b_key: return None cypher = """ MATCH (a:Drug)-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]->(g)<-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]-(b:Drug) WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a)) AND (b.rxcui = $b OR toLower(b.name) = toLower($b)) AND a <> b RETURN g.symbol AS shared_target, labels(g)[0] AS target_kind, a.name AS a_name, b.name AS b_name LIMIT 1 """ rows = await _safe_query(cypher, {"a": a_key, "b": b_key}) if not rows: return None r = rows[0] return { "severity": "moderate", "mechanism": f"Shared {r.get('target_kind', 'target')}: {r.get('shared_target')}", "evidence_level": "indirect", "management": "Monitor for additive or competing effects.", "references": [], "a_name": r.get("a_name"), "b_name": r.get("b_name"), } async def _try_ddi_gnn(drug_pairs): if not os.path.exists(DDI_GNN_CKPT): return None try: import torch # noqa: F401 except ImportError: return None return None # phase-2 async def predict( *, medications: list, add_drug: dict = None, ) -> DdiSpec: """Predict drug-drug interactions across the regimen. Args: medications: list of {name, rxcui?} currently on the patient add_drug: optionally evaluate adding this drug (for what-if) """ drugs = list(medications or []) if add_drug: drugs = drugs + [add_drug] if len(drugs) < 2: return DdiSpec(pairs=[], n_pairs_evaluated=0, regimen_risk="none", model="kg_walks") pairs_out = [] n_evaluated = 0 for i in range(len(drugs)): for j in range(i + 1, len(drugs)): n_evaluated += 1 a, b = drugs[i], drugs[j] try: hit = await _kg_pairwise(a, b) if hit is None: hit = await _kg_via_target(a, b) except Exception as e: logger.debug(f"DDI lookup failed for ({a},{b}): {e}") continue if hit is None: continue pairs_out.append(DdiPair( drug_a=a.get("name") or a.get("rxcui"), drug_b=b.get("name") or b.get("rxcui"), rxcui_a=a.get("rxcui"), rxcui_b=b.get("rxcui"), severity=hit.get("severity") or "unknown", mechanism=hit.get("mechanism") or "", evidence_level=hit.get("evidence_level") or "kg", management=hit.get("management") or "", references=hit.get("references") or [], )) pairs_out.sort(key=lambda p: _SEVERITY_RANK.get(p.severity, 0), reverse=True) if not pairs_out: regimen_risk = "none" else: max_sev = pairs_out[0].severity regimen_risk = max_sev or "none" return DdiSpec( pairs=pairs_out, n_pairs_evaluated=n_evaluated, regimen_risk=regimen_risk, model="ddi_gnn" if os.path.exists(DDI_GNN_CKPT) else "kg_walks", )