"""Drug repurposing — TxGNN-style link prediction Disease ↔ Drug.

Bootstrap: wraps `drug_repurposer.find_repurposing_candidates` (KG walks
Disease → Gene → Drug via Cypher).

Phase 2: TxGNN fine-tuned on PrimeKG + our enriched KG (gemeo/train/txgnn.py).
When the checkpoint exists, it overrides the bootstrap path.

We also enrich each candidate with SUS dispensation status — whether a PCDT
exists for the candidate's disease and whether the drug appears in that
PCDT's therapy list. Availability in the patient's UF is not yet
cross-referenced; the requested UF is only echoed back on each candidate.
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from .types import DrugSpec
logger = logging.getLogger("gemeo.repurpose")

# Location of the fine-tuned TxGNN weights, resolved once at import time.
# The GEMEO_TXGNN_CKPT environment variable (even if set to an empty string)
# takes precedence over the default path under ``artifacts/`` next to this file.
_DEFAULT_TXGNN_CKPT = os.path.join(os.path.dirname(__file__), "artifacts", "txgnn.pt")
TXGNN_CKPT = os.environ.get("GEMEO_TXGNN_CKPT", _DEFAULT_TXGNN_CKPT)
async def _try_txgnn(space, embedding):
    """Run TxGNN prediction if a checkpoint is present, else return ``None``.

    The checkpoint path is the module-level ``TXGNN_CKPT``. ``os.path.exists``
    (not ``isfile``) is used on purpose: an overridden checkpoint may be a
    directory.

    Returns:
        Whatever ``train.txgnn.predict`` returns (``find`` treats it as a
        ``DrugSpec``), or ``None`` when there is no checkpoint or when importing
        the training module / predicting raises. In that case the caller falls
        back to the KG-walk bootstrap.
    """
    if not os.path.exists(TXGNN_CKPT):
        return None
    try:
        # Imported lazily: the training stack is only needed when a checkpoint exists.
        from .train import txgnn as tx_mod
        return await tx_mod.predict(space, embedding, TXGNN_CKPT)
    except Exception as e:
        # Best-effort: degrade to the bootstrap path, but keep the failure
        # visible (summary at warning, full traceback at debug).
        logger.warning("TxGNN predict failed, falling back to KG walks: %s", e)
        logger.debug("TxGNN predict traceback", exc_info=True)
        return None
def _enrich_with_sus(candidates: list, sus_region: Optional[str]) -> list:
"""Mark each candidate with PCDT availability + UF dispensation."""
try:
from brazilian_context import get_pcdt
except ImportError:
return candidates
out = []
for c in candidates:
item = dict(c) if isinstance(c, dict) else {
"name": getattr(c, "name", None),
"rxcui": getattr(c, "rxcui", None),
"score": getattr(c, "score", None),
"mechanism": getattr(c, "mechanism", None),
"status": getattr(c, "status", None),
"for_disease": getattr(c, "for_disease", None),
"for_orpha": getattr(c, "for_orpha", None),
}
sus_avail = False
in_pcdt = False
if item.get("for_orpha"):
try:
pcdt = get_pcdt(item["for_orpha"])
except Exception:
pcdt = None
if pcdt:
in_pcdt = True
# heuristic: drug name appears in PCDT therapy list
therapies = (pcdt.get("therapies") or []) + (pcdt.get("medicamentos") or [])
drug_name_lower = (item.get("name") or "").lower()
if drug_name_lower and any(drug_name_lower in (str(t) or "").lower() for t in therapies):
sus_avail = True
item["sus_in_pcdt"] = in_pcdt
item["sus_dispensed"] = sus_avail # TODO: cross-ref APAC by UF when available
item["sus_uf"] = sus_region
out.append(item)
return out
# Upper bound on KG-walk candidates serialized into the returned DrugSpec.
_KG_MAX_CANDIDATES = 20

# (gemeo key, upstream alias or None). The upstream
# ``drug_repurposer.DrugCandidate`` dataclass uses ``drug_name``/``drug_id``/
# ``evidence_level`` where gemeo.types.DrugCandidate uses ``name``/``rxcui``/
# ``status``. Probing only the gemeo names once serialized every candidate
# as {"name": None, "rxcui": None, ...}.
_KG_CANDIDATE_FIELDS = (
    ("name", "drug_name"),
    ("rxcui", "drug_id"),
    ("score", None),
    ("mechanism", None),
    ("status", "evidence_level"),
    ("for_disease", "disease_name"),
    ("for_orpha", "orpha_code"),
    ("evidence", None),
)


def _serialize_kg_candidate(c) -> dict:
    """Map one upstream candidate (dict or object) onto gemeo's key names.

    Objects: each gemeo field is read, and its alias is used when the gemeo
    value is missing or falsy. Dicts: a copy is returned with the same alias
    back-fill, applied only when the alias key holds a value (this presumably
    covers ``asdict``-style upstream output — verify), so the dict's shape is
    otherwise preserved.
    """
    if isinstance(c, dict):
        item = dict(c)
        for key, alias in _KG_CANDIDATE_FIELDS:
            if alias and not item.get(key) and item.get(alias):
                item[key] = item[alias]
        return item
    item = {}
    for key, alias in _KG_CANDIDATE_FIELDS:
        value = getattr(c, key, None)
        if alias and not value:
            value = getattr(c, alias, None)
        item[key] = value
    return item


def _kg_count(value, default: int) -> int:
    """Return ``value`` as a positive int, or ``default`` if missing, invalid or <= 0."""
    try:
        count = int(value)
    except (TypeError, ValueError):
        return default
    return count if count > 0 else default


async def find(space, embedding=None, sus_region: Optional[str] = None) -> DrugSpec:
    """Find drug repurposing candidates for the digital twin.

    Tries TxGNN first (``_try_txgnn``). If that yields nothing, falls back to
    ``drug_repurposer.find_repurposing_candidates`` (KG walks) and keeps at
    most ``_KG_MAX_CANDIDATES`` candidates. Either way the candidates are
    enriched with SUS status by ``_enrich_with_sus``.

    Args:
        space: Forwarded unchanged to TxGNN and the KG walker (type not visible here).
        embedding: Optional embedding, forwarded to TxGNN only.
        sus_region: UF code echoed into each candidate's ``sus_uf``.

    Returns:
        The TxGNN spec when available. Otherwise a ``DrugSpec`` with
        ``model="kg_walks"``. If ``drug_repurposer`` is missing or fails, that
        spec has an empty candidate list instead of raising.
    """
    spec = await _try_txgnn(space, embedding)
    if spec is not None:
        spec.candidates = _enrich_with_sus(spec.candidates, sus_region)
        return spec

    result = None
    try:
        from drug_repurposer import find_repurposing_candidates
    except ImportError as e:
        # Missing module stays quiet (debug), as before.
        logger.debug("drug_repurposer unavailable: %s", e)
    else:
        try:
            result = await find_repurposing_candidates(space)
        except Exception as e:
            # Still best-effort, but a real upstream failure should be visible.
            logger.warning("drug_repurposer failed: %s", e)

    # Accept both dict-shaped and attribute-shaped results.
    if result is None:
        candidates, raw_count = [], None
    elif isinstance(result, dict):
        candidates = result.get("candidates") or []
        raw_count = result.get("n_evaluated")
    else:
        candidates = getattr(result, "candidates", None) or []
        raw_count = getattr(result, "n_evaluated", None)
    candidates = list(candidates)
    # One coercion rule for both shapes: a None/non-numeric count no longer
    # raises, it falls back to the number of candidates returned.
    n_eval = _kg_count(raw_count, len(candidates))

    serialized = [_serialize_kg_candidate(c) for c in candidates[:_KG_MAX_CANDIDATES]]
    serialized = _enrich_with_sus(serialized, sus_region)
    return DrugSpec(
        candidates=serialized,
        model="kg_walks",
        n_evaluated=n_eval or len(serialized),
    )
)
|