gemeo-twin-stack / src /gemeo /repurpose.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Drug repurposing — TxGNN-style link prediction Disease ↔ Drug.
Bootstrap: wraps `drug_repurposer.find_repurposing_candidates` (KG walks
Disease → Gene → Drug via Cypher).
Phase 2: TxGNN fine-tune on PrimeKG + our enriched KG (gemeo/train/txgnn.py).
When the checkpoint exists, the TxGNN path overrides the bootstrap path.
We also enrich each candidate with SUS dispensation status: whether the
candidate's disease has a PCDT, whether the drug appears in that PCDT's
therapy list, and the patient's UF (cross-referencing dispensation per UF
is not implemented yet).
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from .types import DrugSpec
logger = logging.getLogger("gemeo.repurpose")

# Default TxGNN checkpoint: ``artifacts/txgnn.pt`` next to this module.
# The GEMEO_TXGNN_CKPT environment variable overrides it; the value is
# resolved once, when the module is imported.
_DEFAULT_TXGNN_CKPT = os.path.join(os.path.dirname(__file__), "artifacts", "txgnn.pt")
TXGNN_CKPT = os.environ.get("GEMEO_TXGNN_CKPT", _DEFAULT_TXGNN_CKPT)
async def _try_txgnn(space, embedding):
    """Run the TxGNN predictor if its checkpoint file is present.

    Returns the awaited result of ``gemeo.train.txgnn.predict(space,
    embedding, TXGNN_CKPT)``. Returns ``None`` in two cases:

    * no file exists at ``TXGNN_CKPT``;
    * importing or running the predictor raises. The error is logged as a
      warning.

    A ``None`` result lets the caller fall back to another strategy.
    """
    if os.path.exists(TXGNN_CKPT):
        try:
            # Imported lazily so the module loads without the training stack.
            from .train import txgnn as tx_mod
            return await tx_mod.predict(space, embedding, TXGNN_CKPT)
        except Exception as e:
            logger.warning(f"TxGNN predict failed: {e}")
    return None
def _enrich_with_sus(candidates: list, sus_region: Optional[str]) -> list:
"""Mark each candidate with PCDT availability + UF dispensation."""
try:
from brazilian_context import get_pcdt
except ImportError:
return candidates
out = []
for c in candidates:
item = dict(c) if isinstance(c, dict) else {
"name": getattr(c, "name", None),
"rxcui": getattr(c, "rxcui", None),
"score": getattr(c, "score", None),
"mechanism": getattr(c, "mechanism", None),
"status": getattr(c, "status", None),
"for_disease": getattr(c, "for_disease", None),
"for_orpha": getattr(c, "for_orpha", None),
}
sus_avail = False
in_pcdt = False
if item.get("for_orpha"):
try:
pcdt = get_pcdt(item["for_orpha"])
except Exception:
pcdt = None
if pcdt:
in_pcdt = True
# heuristic: drug name appears in PCDT therapy list
therapies = (pcdt.get("therapies") or []) + (pcdt.get("medicamentos") or [])
drug_name_lower = (item.get("name") or "").lower()
if drug_name_lower and any(drug_name_lower in (str(t) or "").lower() for t in therapies):
sus_avail = True
item["sus_in_pcdt"] = in_pcdt
item["sus_dispensed"] = sus_avail # TODO: cross-ref APAC by UF when available
item["sus_uf"] = sus_region
out.append(item)
return out
# gemeo.types.DrugCandidate field -> equivalent upstream
# drug_repurposer.DrugCandidate field name.
_UPSTREAM_FIELD_ALIASES = {
    "name": "drug_name",
    "rxcui": "drug_id",
    "status": "evidence_level",
    "for_disease": "disease_name",
    "for_orpha": "orpha_code",
}
# Key order of candidates serialized from attribute objects.
_SERIALIZED_FIELDS = (
    "name",
    "rxcui",
    "score",
    "mechanism",
    "status",
    "for_disease",
    "for_orpha",
    "evidence",
)
# Only the first N bootstrap candidates are kept in the returned spec.
_MAX_BOOTSTRAP_CANDIDATES = 20


def _serialize_candidate(c) -> dict:
    """Normalise one bootstrap candidate to the gemeo.types.DrugCandidate dict shape."""
    if isinstance(c, dict):
        # Keep every key, but back-fill canonical names from the upstream
        # aliases when missing or falsy, so downstream lookups of
        # ``name`` / ``for_orpha`` see real values.
        out = dict(c)
        for field, alias in _UPSTREAM_FIELD_ALIASES.items():
            if not out.get(field) and alias in out:
                out[field] = out[alias]
        return out
    out = {}
    for field in _SERIALIZED_FIELDS:
        value = getattr(c, field, None)
        alias = _UPSTREAM_FIELD_ALIASES.get(field)
        if not value and alias is not None:
            value = getattr(c, alias, None)
        out[field] = value
    return out


def _count_or(value, fallback: int) -> int:
    """Return ``int(value)`` when it is a positive count, otherwise ``fallback``."""
    try:
        count = int(value)
    except (TypeError, ValueError, OverflowError):
        return fallback
    return count if count > 0 else fallback


def _unpack_result(result) -> tuple:
    """Split a drug_repurposer result (dict, object or None) into (candidates, n_evaluated)."""
    if result is None:
        return [], 0
    if isinstance(result, dict):
        raw_candidates = result.get("candidates")
        raw_count = result.get("n_evaluated")
    else:
        raw_candidates = getattr(result, "candidates", None)
        raw_count = getattr(result, "n_evaluated", None)
    candidates = list(raw_candidates or [])
    # Both result shapes fall back to the candidate count when n_evaluated
    # is missing, None, non-numeric or non-positive.
    return candidates, _count_or(raw_count, len(candidates))


async def find(space, embedding=None, sus_region: Optional[str] = None) -> DrugSpec:
    """Find drug-repurposing candidates for the digital twin.

    The TxGNN checkpoint path is tried first. If it yields a spec, that
    spec's candidates are SUS-enriched and the spec is returned. Otherwise
    the ``drug_repurposer`` KG-walk bootstrap is used:

    * its first 20 candidates are normalised to the gemeo dict shape;
    * they are SUS-enriched;
    * they are wrapped in a ``DrugSpec`` with ``model="kg_walks"``.

    Args:
        space: twin context forwarded unchanged to both predictors.
        embedding: optional embedding, forwarded to TxGNN only.
        sus_region: patient's UF, recorded on each candidate by the SUS
            enrichment step.

    Returns:
        A DrugSpec. If the bootstrap is missing or fails, the spec has an
        empty candidate list instead of an exception being raised.
    """
    # Phase 2 path: a TxGNN checkpoint, when present, takes precedence.
    spec = await _try_txgnn(space, embedding)
    if spec is not None:
        spec.candidates = _enrich_with_sus(spec.candidates, sus_region)
        return spec

    # Bootstrap path: KG walks via drug_repurposer.
    candidates: list = []
    n_eval = 0
    try:
        from drug_repurposer import find_repurposing_candidates
    except ImportError as e:
        # Module not installed is an expected deployment mode; stay quiet.
        logger.debug("drug_repurposer failed: %s", e)
    else:
        try:
            candidates, n_eval = _unpack_result(await find_repurposing_candidates(space))
        except Exception as e:
            # Installed but failing: surface it at warning level, like the
            # TxGNN path does, then degrade to an empty candidate list.
            logger.warning("drug_repurposer failed: %s", e)

    serialized = [_serialize_candidate(c) for c in candidates[:_MAX_BOOTSTRAP_CANDIDATES]]
    serialized = _enrich_with_sus(serialized, sus_region)
    return DrugSpec(
        candidates=serialized,
        model="kg_walks",
        n_evaluated=n_eval or len(serialized),
    )