# timmers's picture
# GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
# 089d665 verified
"""TxGNN fine-tune for drug repurposing (Phase 2).
Following Huang et al. *Nature Medicine* 2024. Zero-shot disease–drug link
prediction using a heterogeneous GNN over PrimeKG + our enriched KG.
Approach:
1. Initialize node features with our raras-app `fused_embeddings.npz`.
2. Train HGT-style message passing on PrimeKG drug-disease subgraph.
3. Add Brazilian SUS auxiliary head — bias predictions toward drugs in
PCDT/CEAF when the patient is in the SUS context.
4. Save inference module exposing `predict(space, embedding, ckpt) -> DrugSpec`.
"""
from __future__ import annotations
import os
import logging
from typing import Optional
# Module-level logger, registered under the fixed name "gemeo.train.txgnn".
logger = logging.getLogger("gemeo.train.txgnn")
# Checkpoint location: <directory of this file>/../artifacts/txgnn.pt.
# The path is joined but not normalized, so it keeps the literal ".." segment.
# NOTE(review): nothing in the visible code reads CKPT — `predict` takes its
# own `ckpt_path` argument; presumably callers pass this constant in. Confirm.
CKPT = os.path.join(os.path.dirname(__file__), "..", "artifacts", "txgnn.pt")
async def predict(space, embedding, ckpt_path: str):
    """Inference entry point used by `gemeo.repurpose.find`.

    This is still a placeholder: every path through the function returns
    ``None``, which tells the caller to defer to the bootstrap path.

    Args:
        space: Not read by the current stub.
        embedding: Not read by the current stub.
        ckpt_path: Filesystem path of the trained checkpoint. Nothing beyond
            an existence check is done with it yet.

    Returns:
        Always ``None`` for now. The real implementation (planned for PR-2)
        is meant to load the checkpoint, run a forward pass and return a
        DrugSpec.
    """
    if os.path.exists(ckpt_path):
        # torch is an optional dependency and is only probed once a
        # checkpoint is actually present on disk.
        try:
            import torch  # noqa: F401
        except ImportError:
            return None
        # Load checkpoint, run forward, return DrugSpec.
        # Stub: defer to bootstrap. Real impl in PR-2.
    return None
def train(epochs: int = 80):
    """Training entry point; run as `python -m gemeo.train.txgnn`.

    Scaffold only: it emits a single INFO log line and returns ``None``.

    Args:
        epochs: Intended number of training epochs. The scaffold does not
            read it yet.
    """
    scaffold_notice = (
        "TxGNN scaffold — fill PrimeKG drug-disease loader + link-pred head"
    )
    logger.info(scaffold_notice)
    # TODO: load PrimeKG, freeze HGT base, fine-tune drug-disease head
    # TODO: SUS auxiliary loss using brazilian_context.PCDT mappings
    # TODO: save model with .predict(...) method
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging on the root logger so the
    # scaffold's log line is visible, then run training with default arguments.
    logging.basicConfig(level=logging.INFO)
    train()