File size: 5,821 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""Drug-drug interaction prediction.

Critical for case-driven workflow: rare-disease patients are frequently
polymedicated (PCDT therapy + symptomatic + comorbidity meds), and
interactions are a top cause of preventable harm.

Strategy:
  1. **KG walks** over Drug↔Drug↔Gene/Pathway/CYP edges in our enriched
     biomedical graph (DrugBank + DDInter + CPIC, indexed by raras-app).
  2. **Severity classification** via interaction edges' attributes
     (severity ∈ {minor, moderate, major, contraindicated}).
  3. **PK/PD mechanism narration** — extracted from interaction edge
     metadata; LLM-rewritten for clinician-friendly text.
  4. **Phase-2 GNN** (gemeo/train/ddi_gnn.py) for unseen pairs — link
     prediction with mechanism-aware edge types.

Returns a `DdiSpec` with a ranked list of pairwise predicted interactions
plus a single overall `regimen_risk` for the regimen.
"""
from __future__ import annotations
import logging
import os
from typing import Optional

from .types import DdiSpec, DdiPair

# Logger for this module, registered under the fixed name "gemeo.ddi".
logger = logging.getLogger("gemeo.ddi")

# Path of the phase-2 DDI GNN checkpoint: the GEMEO_DDI_CKPT environment
# variable when it is set, otherwise artifacts/ddi_gnn.pt beside this module.
DDI_GNN_CKPT = os.getenv(
    "GEMEO_DDI_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "ddi_gnn.pt"),
)


async def _safe_query(cypher: str, params: Optional[dict] = None) -> list:
    """Run a Cypher query through ``space_graph._safe_query``, best-effort.

    Args:
        cypher: Cypher query text.
        params: Query parameters; ``None`` is forwarded as an empty dict.

    Returns:
        Whatever the delegate returns on success (callers in this module
        treat it as a list of row mappings), or ``[]`` when the import or
        the query raises.
    """
    try:
        # The import sits inside the try, so a missing space_graph module is
        # handled like any other failure instead of breaking this module.
        from space_graph import _safe_query as q
        return await q(cypher, params or {}, timeout=10.0)
    except Exception as e:
        # An empty result is indistinguishable from "no interaction found",
        # so the failure has to show up in normal logs, not only at DEBUG.
        logger.warning("cypher failed: %s", e)
        return []


_SEVERITY_RANK = {
    "contraindicated": 4,
    "major": 3,
    "moderate": 2,
    "minor": 1,
    "unknown": 1,
    None: 1,
}


async def _kg_pairwise(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Look up a single Drug-Drug interaction edge in Neo4j."""
    a_key = drug_a.get("rxcui") or drug_a.get("name")
    b_key = drug_b.get("rxcui") or drug_b.get("name")
    if not a_key or not b_key:
        return None
    cypher = """
    MATCH (a:Drug)-[r:INTERACTS_WITH]-(b:Drug)
    WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a))
      AND (b.rxcui = $b OR toLower(b.name) = toLower($b))
    RETURN r.severity AS severity,
           r.mechanism AS mechanism,
           r.evidence_level AS evidence_level,
           r.management AS management,
           r.references AS references,
           a.name AS a_name, b.name AS b_name
    LIMIT 1
    """
    rows = await _safe_query(cypher, {"a": a_key, "b": b_key})
    return rows[0] if rows else None


async def _kg_via_target(drug_a: dict, drug_b: dict) -> Optional[dict]:
    """Indirect interaction: shared CYP enzyme, transporter, or target."""
    a_key = drug_a.get("rxcui") or drug_a.get("name")
    b_key = drug_b.get("rxcui") or drug_b.get("name")
    if not a_key or not b_key:
        return None
    cypher = """
    MATCH (a:Drug)-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]->(g)<-[:METABOLIZED_BY|TARGETS|INHIBITS|INDUCES]-(b:Drug)
    WHERE (a.rxcui = $a OR toLower(a.name) = toLower($a))
      AND (b.rxcui = $b OR toLower(b.name) = toLower($b))
      AND a <> b
    RETURN g.symbol AS shared_target,
           labels(g)[0] AS target_kind,
           a.name AS a_name, b.name AS b_name
    LIMIT 1
    """
    rows = await _safe_query(cypher, {"a": a_key, "b": b_key})
    if not rows:
        return None
    r = rows[0]
    return {
        "severity": "moderate",
        "mechanism": f"Shared {r.get('target_kind', 'target')}: {r.get('shared_target')}",
        "evidence_level": "indirect",
        "management": "Monitor for additive or competing effects.",
        "references": [],
        "a_name": r.get("a_name"),
        "b_name": r.get("b_name"),
    }


async def _try_ddi_gnn(drug_pairs):
    """Phase-2 hook for the GNN link predictor; currently always returns None.

    When the checkpoint file at ``DDI_GNN_CKPT`` exists, the function checks
    that ``torch`` can be imported, but no model is loaded or run yet and
    ``drug_pairs`` is not read.
    """
    checkpoint_present = os.path.exists(DDI_GNN_CKPT)
    if checkpoint_present:
        try:
            import torch  # noqa: F401
        except ImportError:
            pass
    # phase-2: inference is not implemented, so every path yields None.
    return None


def _severity_rank(severity) -> int:
    """Return the sort weight of a severity label, ignoring letter case."""
    label = severity.lower() if isinstance(severity, str) else severity
    return _SEVERITY_RANK.get(label, 0)


async def predict(
    *,
    medications: list,
    add_drug: Optional[dict] = None,
) -> DdiSpec:
    """Predict drug-drug interactions across the regimen.

    Every unordered pair of drugs is checked, first for a direct interaction
    edge and then, if none is found, for an indirect one via a shared target.

    Args:
        medications: list of {name, rxcui?} currently on the patient
        add_drug:    optionally evaluate adding this drug (for what-if)

    Returns:
        A ``DdiSpec`` whose ``pairs`` are sorted most severe first,
        ``n_pairs_evaluated`` counts every pair attempted (including pairs
        whose lookup failed or found nothing), ``regimen_risk`` is the
        severity of the top-ranked pair or ``"none"``, and ``model`` is
        ``"kg_walks"``.
    """
    from itertools import combinations

    drugs = list(medications or [])  # copy, so the caller's list is not mutated
    if add_drug:
        drugs.append(add_drug)

    if len(drugs) < 2:
        return DdiSpec(pairs=[], n_pairs_evaluated=0, regimen_risk="none", model="kg_walks")

    pairs_out = []
    n_evaluated = 0
    for a, b in combinations(drugs, 2):
        n_evaluated += 1
        try:
            hit = await _kg_pairwise(a, b)
            if hit is None:
                hit = await _kg_via_target(a, b)
        except Exception as e:
            # Best-effort per pair: one bad entry must not abort the regimen,
            # but a skipped pair has to be visible in normal logs.
            logger.warning("DDI lookup failed for (%s,%s): %s", a, b, e)
            continue
        if hit is None:
            continue
        pairs_out.append(DdiPair(
            drug_a=a.get("name") or a.get("rxcui"),
            drug_b=b.get("name") or b.get("rxcui"),
            rxcui_a=a.get("rxcui"),
            rxcui_b=b.get("rxcui"),
            severity=hit.get("severity") or "unknown",
            mechanism=hit.get("mechanism") or "",
            evidence_level=hit.get("evidence_level") or "kg",
            management=hit.get("management") or "",
            references=hit.get("references") or [],
        ))

    # Case-insensitive ranking, so a label such as "Major" is not pushed to
    # the bottom and hidden from regimen_risk.
    pairs_out.sort(key=lambda p: _severity_rank(p.severity), reverse=True)

    regimen_risk = (pairs_out[0].severity or "none") if pairs_out else "none"

    return DdiSpec(
        pairs=pairs_out,
        n_pairs_evaluated=n_evaluated,
        regimen_risk=regimen_risk,
        # Only KG lookups feed this result; the GNN path is never invoked
        # here, so the checkpoint file's existence must not change the label.
        model="kg_walks",
    )