"""TGNN — Temporal Graph Network for patient trajectory (Phase 2).

Architecture follows TRANS (IJCAI 2024) + TETGN improvements:
  - Temporal heterogeneous graph: nodes are (patient_snapshot at time t).
  - Each snapshot is connected to phenotypes/labs/genes valid at that time.
  - Trajectory-aware aggregation (TETGN) — message passing weighted by
    elapsed time and clinical-event type.
  - Multi-task heads: next phenotype, next complication, time-to-event
    (lung disease, hospitalization, etc.).

Intended to be trained on the snapshot chains stored in PatientSpace +
Neo4j ClinicalSnapshot.

Status: scaffold only — ``train`` logs a placeholder message and ``predict``
always returns ``None``; the architecture above is the target design and is
not yet implemented in this module.
"""
| from __future__ import annotations |
| import os |
| import logging |
| from typing import Optional |
|
|
# Logger name is spelled out (rather than derived from __name__) so records
# are always tagged "gemeo.train.tgnn".
logger = logging.getLogger("gemeo.train.tgnn")


# TGNN trajectory checkpoint location: <module dir>/../artifacts/tgnn_trajectory.pt.
# The path is joined as-is (the ".." segment is not normalized away).
CKPT = os.path.join(
    os.path.dirname(__file__), os.pardir, "artifacts", "tgnn_trajectory.pt"
)
|
|
|
|
async def predict(space, ckpt_path: str):
    """Trajectory inference hook called by `gemeo.trajectory.predict`.

    This is currently a stub: every path returns ``None`` and ``space`` is
    never read.

    Args:
        space: Forwarded by the caller; not used by this stub (presumably a
            PatientSpace instance — confirm against the caller).
        ckpt_path: Filesystem path of the TGNN checkpoint.

    Returns:
        ``None`` in all cases. The only checks performed so far are that the
        checkpoint file exists and that PyTorch can be imported.
    """
    if os.path.exists(ckpt_path):
        try:
            import torch  # noqa: F401  -- availability probe only, not used yet
        except ImportError:
            pass  # PyTorch missing: nothing more to do.
    # Checkpoint loading and model inference are not implemented yet.
    return None
|
|
|
|
def train(epochs: int = 60):
    """Training entry point; run as `python -m gemeo.train.tgnn`.

    Currently a scaffold: it only logs (at INFO level, via the module logger)
    that longitudinal labels from PatientSpace snapshots are still needed.
    No model is built or trained.

    Args:
        epochs: Number of training epochs. Accepted but not used yet.
    """
    logger.info("TGNN scaffold — needs longitudinal labels from PatientSpace snapshots")
|
|
|
|
if __name__ == "__main__":
    # Configure root logging at INFO so the message logged by train() is
    # shown, then invoke the training entry point.
    logging.basicConfig(level="INFO")
    train()
|
|