"""Drug repurposing — TxGNN-style link prediction between Disease ↔ Drug.

Bootstrap: wraps `drug_repurposer.find_repurposing_candidates` (KG walks
Disease → Gene → Drug via Cypher).

Phase 2: TxGNN fine-tuned on PrimeKG + our enriched KG (gemeo/train/txgnn.py).
When a checkpoint exists, it overrides the bootstrap path.

We also enrich each candidate with SUS dispensation status — whether the
disease has a PCDT and whether the drug is listed in it. The patient's UF is
recorded on each candidate as-is; it is not used to filter availability.
"""
| from __future__ import annotations |
| import logging |
| import os |
| from typing import Optional |
|
|
| from .types import DrugSpec |
|
|
# Module-level logger shared by every function in this file.
logger = logging.getLogger("gemeo.repurpose")

# Location of the fine-tuned TxGNN checkpoint. The GEMEO_TXGNN_CKPT
# environment variable takes precedence; otherwise the path defaults to
# "artifacts/txgnn.pt" next to this module.
TXGNN_CKPT = os.getenv(
    "GEMEO_TXGNN_CKPT",
    default=os.path.join(os.path.dirname(__file__), "artifacts", "txgnn.pt"),
)
|
|
|
|
async def _try_txgnn(space, embedding):
    """Run the TxGNN predictor when its checkpoint file is present.

    Returns whatever ``train.txgnn.predict(space, embedding, TXGNN_CKPT)``
    produces. Returns ``None`` when no file exists at ``TXGNN_CKPT`` or when
    importing / running the predictor raises; the failure is logged as a
    warning so the caller can fall back to the bootstrap path.
    """
    checkpoint_present = os.path.exists(TXGNN_CKPT)
    if checkpoint_present:
        try:
            # Imported lazily: the training package is only needed once a
            # checkpoint has actually been produced.
            from .train import txgnn as tx_mod

            prediction = await tx_mod.predict(space, embedding, TXGNN_CKPT)
        except Exception as e:
            logger.warning(f"TxGNN predict failed: {e}")
        else:
            return prediction
    return None
|
|
|
|
| def _enrich_with_sus(candidates: list, sus_region: Optional[str]) -> list: |
| """Mark each candidate with PCDT availability + UF dispensation.""" |
| try: |
| from brazilian_context import get_pcdt |
| except ImportError: |
| return candidates |
|
|
| out = [] |
| for c in candidates: |
| item = dict(c) if isinstance(c, dict) else { |
| "name": getattr(c, "name", None), |
| "rxcui": getattr(c, "rxcui", None), |
| "score": getattr(c, "score", None), |
| "mechanism": getattr(c, "mechanism", None), |
| "status": getattr(c, "status", None), |
| "for_disease": getattr(c, "for_disease", None), |
| "for_orpha": getattr(c, "for_orpha", None), |
| } |
| sus_avail = False |
| in_pcdt = False |
| if item.get("for_orpha"): |
| try: |
| pcdt = get_pcdt(item["for_orpha"]) |
| except Exception: |
| pcdt = None |
| if pcdt: |
| in_pcdt = True |
| |
| therapies = (pcdt.get("therapies") or []) + (pcdt.get("medicamentos") or []) |
| drug_name_lower = (item.get("name") or "").lower() |
| if drug_name_lower and any(drug_name_lower in (str(t) or "").lower() for t in therapies): |
| sus_avail = True |
| item["sus_in_pcdt"] = in_pcdt |
| item["sus_dispensed"] = sus_avail |
| item["sus_uf"] = sus_region |
| out.append(item) |
| return out |
|
|
|
|
# Upper bound on bootstrap (KG-walk) candidates that get serialised.
_MAX_BOOTSTRAP_CANDIDATES = 20


def _first_attr(obj, *names):
    """Return the first truthy attribute of *obj* among *names*.

    Mirrors ``getattr(obj, a, None) or getattr(obj, b, None)``: when no
    attribute is truthy, the value of the last one looked up is returned.
    """
    value = None
    for attr in names:
        value = getattr(obj, attr, None)
        if value:
            return value
    return value


async def find(space, embedding=None, sus_region: Optional[str] = None) -> DrugSpec:
    """Find drug candidates for the digital twin.

    The TxGNN path is tried first via ``_try_txgnn``; when it yields a spec,
    that spec's ``candidates`` are SUS-enriched and the spec is returned.
    Otherwise the bootstrap path calls
    ``drug_repurposer.find_repurposing_candidates(space)``, keeps the first
    20 candidates, normalises attribute-style candidates to dicts and
    SUS-enriches them.

    Args:
        space: Passed through unchanged to the TxGNN predictor and to
            ``find_repurposing_candidates``.
        embedding: Passed through unchanged to the TxGNN predictor.
        sus_region: Forwarded to ``_enrich_with_sus``.

    Returns:
        The TxGNN spec when available, else a ``DrugSpec`` with
        ``model="kg_walks"``. A bootstrap failure is logged at debug level
        and results in an empty candidate list, not an exception.
    """
    spec = await _try_txgnn(space, embedding)
    if spec is not None:
        spec.candidates = _enrich_with_sus(spec.candidates, sus_region)
        return spec

    candidates = []
    reported_total = None
    try:
        from drug_repurposer import find_repurposing_candidates

        result = await find_repurposing_candidates(space)
        if isinstance(result, dict):
            candidates = result.get("candidates") or []
            reported_total = result.get("n_evaluated")
        elif result is not None:
            candidates = getattr(result, "candidates", None) or []
            reported_total = getattr(result, "n_evaluated", None)
    except Exception as e:
        # Best effort: the bootstrap source is optional, so fall through
        # with whatever was collected (possibly nothing).
        logger.debug("drug_repurposer failed: %s", e)

    # Both result shapes share one rule: a missing, None or non-numeric
    # count falls back to the number of candidates actually returned,
    # before truncation.
    try:
        n_eval = int(reported_total or len(candidates))
    except (TypeError, ValueError, OverflowError):
        n_eval = len(candidates)

    serialized = []
    for cand in candidates[:_MAX_BOOTSTRAP_CANDIDATES]:
        if isinstance(cand, dict):
            serialized.append(cand)
            continue
        serialized.append({
            "name": _first_attr(cand, "name", "drug_name"),
            "rxcui": _first_attr(cand, "rxcui", "drug_id"),
            "score": getattr(cand, "score", None),
            "mechanism": getattr(cand, "mechanism", None),
            "status": _first_attr(cand, "status", "evidence_level"),
            "for_disease": _first_attr(cand, "for_disease", "disease_name"),
            "for_orpha": _first_attr(cand, "for_orpha", "orpha_code"),
            "evidence": getattr(cand, "evidence", None),
        })

    serialized = _enrich_with_sus(serialized, sus_region)

    return DrugSpec(
        candidates=serialized,
        model="kg_walks",
        n_evaluated=n_eval or len(serialized),
    )
|
|