File size: 4,951 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Drug repurposing — TxGNN-style link prediction Disease ↔ Drug.

Bootstrap: wraps `drug_repurposer.find_repurposing_candidates` (KG walks
Disease → Gene → Drug via Cypher).

Phase 2: TxGNN fine-tune on PrimeKG + our enriched KG (gemeo/train/txgnn.py).
When checkpoint exists, overrides the bootstrap path.

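We also enrich each candidate with SUS dispensation status — whether a
PCDT exists for the candidate's target disease (ORPHA code) and whether
the drug is named in that PCDT's therapy list. The patient's UF is
recorded on each candidate; per-UF availability is not cross-referenced yet.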
"""
from __future__ import annotations
import logging
import os
from typing import Optional

from .types import DrugSpec

# Module logger for the repurposing pipeline.
logger = logging.getLogger("gemeo.repurpose")

# Checkpoint shipped next to this module, used when GEMEO_TXGNN_CKPT is unset.
_DEFAULT_TXGNN_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "txgnn.pt"
)

# Path of the TxGNN checkpoint; the environment variable wins over the default.
TXGNN_CKPT = os.environ.get("GEMEO_TXGNN_CKPT", _DEFAULT_TXGNN_CKPT)


async def _try_txgnn(space, embedding):
    """Run the TxGNN predictor if its checkpoint file is present.

    Returns whatever ``gemeo.train.txgnn.predict`` returns, or ``None`` when
    the checkpoint at ``TXGNN_CKPT`` does not exist or when importing or
    running the predictor raises (the failure is logged as a warning).
    """
    checkpoint_present = os.path.exists(TXGNN_CKPT)
    if checkpoint_present:
        try:
            # Imported lazily so the training package is only loaded when a
            # checkpoint is actually available.
            from .train import txgnn as predictor
            prediction = await predictor.predict(space, embedding, TXGNN_CKPT)
        except Exception as e:
            logger.warning(f"TxGNN predict failed: {e}")
            prediction = None
        return prediction
    return None


def _first_truthy_attr(obj, *names):
    """Return the first truthy attribute of *obj* among *names*, else None."""
    for attr in names:
        value = getattr(obj, attr, None)
        if value:
            return value
    return None


def _candidate_as_dict(candidate) -> dict:
    """Copy a dict candidate, or map a candidate object onto the dict shape.

    Objects are probed for the gemeo field names first and then for the
    `drug_repurposer.DrugCandidate` names (`drug_name`, `drug_id`,
    `evidence_level`, `disease_name`, `orpha_code`), mirroring the mapping
    used by `find`.
    """
    if isinstance(candidate, dict):
        return dict(candidate)
    return {
        "name": _first_truthy_attr(candidate, "name", "drug_name"),
        "rxcui": _first_truthy_attr(candidate, "rxcui", "drug_id"),
        # Plain getattr: a score of 0 / 0.0 must survive.
        "score": getattr(candidate, "score", None),
        "mechanism": getattr(candidate, "mechanism", None),
        "status": _first_truthy_attr(candidate, "status", "evidence_level"),
        "for_disease": _first_truthy_attr(candidate, "for_disease", "disease_name"),
        "for_orpha": _first_truthy_attr(candidate, "for_orpha", "orpha_code"),
        "evidence": getattr(candidate, "evidence", None),
    }


def _pcdt_therapy_names(pcdt) -> list:
    """Collect lowercased entries of a PCDT record's `therapies`/`medicamentos`."""
    getter = getattr(pcdt, "get", None)
    if not callable(getter):
        return []
    names = []
    for key in ("therapies", "medicamentos"):
        entries = getter(key)
        if not entries:
            continue
        # A lone string (or any non-iterable) counts as a single entry, so
        # lists, tuples and scalars are all accepted.
        if isinstance(entries, str) or not hasattr(entries, "__iter__"):
            entries = [entries]
        names.extend(str(entry).lower() for entry in entries)
    return names


def _enrich_with_sus(
    candidates: list,
    sus_region: Optional[str],
    *,
    pcdt_lookup=None,
) -> list:
    """Mark each candidate with PCDT availability + UF dispensation.

    Args:
        candidates: dicts and/or candidate objects. Dicts are shallow-copied;
            objects are converted to dicts (see `_candidate_as_dict`).
        sus_region: stored verbatim under ``"sus_uf"`` on every item.
        pcdt_lookup: optional callable ``orpha_code -> PCDT record or None``.
            Defaults to ``brazilian_context.get_pcdt``.

    Returns:
        A new list of dicts, each with ``sus_in_pcdt``, ``sus_dispensed`` and
        ``sus_uf`` set. If no ``pcdt_lookup`` is given and
        ``brazilian_context`` cannot be imported, *candidates* is returned
        unchanged.
    """
    if pcdt_lookup is None:
        try:
            from brazilian_context import get_pcdt
        except ImportError:
            # Optional dependency missing: hand the input back untouched.
            return candidates
        pcdt_lookup = get_pcdt

    enriched = []
    for candidate in candidates or []:
        item = _candidate_as_dict(candidate)
        in_pcdt = False
        dispensed = False
        orpha = item.get("for_orpha")
        if orpha:
            try:
                pcdt = pcdt_lookup(orpha)
            except Exception:
                # Best effort: a failing lookup is treated as "no PCDT".
                pcdt = None
            if pcdt:
                in_pcdt = True
                # heuristic: drug name appears in PCDT therapy list
                drug = str(item.get("name") or "").lower()
                if drug:
                    dispensed = any(
                        drug in therapy for therapy in _pcdt_therapy_names(pcdt)
                    )
        item["sus_in_pcdt"] = in_pcdt
        item["sus_dispensed"] = dispensed  # TODO: cross-ref APAC by UF when available
        item["sus_uf"] = sus_region
        enriched.append(item)
    return enriched


async def find(space, embedding=None, sus_region: Optional[str] = None) -> DrugSpec:
    """Find drug candidates for the digital twin.

    The TxGNN path (`_try_txgnn`) is tried first; when it yields a spec, its
    candidates are SUS-enriched and the spec is returned. Otherwise the
    function falls back to `drug_repurposer.find_repurposing_candidates`,
    keeps at most the first 20 candidates, and returns them as a
    ``DrugSpec`` with ``model="kg_walks"``.

    Args:
        space: passed through to the TxGNN predictor and to the repurposer.
        embedding: passed through to the TxGNN predictor only.
        sus_region: UF recorded on every enriched candidate.

    Returns:
        A ``DrugSpec``. On the fallback path a repurposer failure yields an
        empty candidate list rather than an exception.
    """
    # try TxGNN first
    spec = await _try_txgnn(space, embedding)
    if spec is not None:
        spec.candidates = _enrich_with_sus(spec.candidates, sus_region)
        return spec

    candidates = []
    reported_total = None
    try:
        from drug_repurposer import find_repurposing_candidates
        result = await find_repurposing_candidates(space)
    except ImportError as e:
        # Optional module not installed: expected in some deployments.
        logger.debug("drug_repurposer failed: %s", e)
    except Exception as e:
        # A real failure silently empties the candidate list, so surface it.
        logger.warning("drug_repurposer failed: %s", e)
    else:
        if isinstance(result, dict):
            candidates = result.get("candidates") or []
            reported_total = result.get("n_evaluated")
        elif result is not None:
            candidates = getattr(result, "candidates", None) or []
            reported_total = getattr(result, "n_evaluated", None)

    candidates = list(candidates)

    # A missing, None, zero or non-numeric `n_evaluated` falls back to the
    # number of candidates received, for dict and object results alike.
    try:
        n_eval = int(reported_total or 0)
    except (TypeError, ValueError):
        n_eval = 0
    n_eval = n_eval or len(candidates)

    # serialize candidate objects
    # The upstream `drug_repurposer.DrugCandidate` dataclass uses field
    # names `drug_name`/`drug_id`/`evidence_level`, so we map them into
    # the gemeo.types.DrugCandidate shape (`name`/`rxcui`/`status`).
    serialized = []
    for c in candidates[:20]:
        if isinstance(c, dict):
            serialized.append(c)
            continue
        serialized.append({
            "name": getattr(c, "name", None) or getattr(c, "drug_name", None),
            "rxcui": getattr(c, "rxcui", None) or getattr(c, "drug_id", None),
            "score": getattr(c, "score", None),
            "mechanism": getattr(c, "mechanism", None),
            "status": getattr(c, "status", None) or getattr(c, "evidence_level", None),
            "for_disease": getattr(c, "for_disease", None) or getattr(c, "disease_name", None),
            "for_orpha": getattr(c, "for_orpha", None) or getattr(c, "orpha_code", None),
            "evidence": getattr(c, "evidence", None),
        })

    serialized = _enrich_with_sus(serialized, sus_region)

    return DrugSpec(
        candidates=serialized,
        model="kg_walks",
        n_evaluated=n_eval,
    )