# timmers's picture
# GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
# 089d665 verified
"""HGT β€” Heterogeneous Graph Transformer for patient embedding (Phase 2).
Following Hu et al. ICML 2020 + clinical adaptations from KG sparsification
papers (arXiv 2510.08655). Trained with two heads:
1. **Disease link prediction** — predict Patient → Disease edges (positives:
   confirmed diagnoses; negatives: random + hard negatives from the feedback ledger).
2. **Patient-Patient contrastive** (SimCLR style) — patients confirmed to have
   the same disease should have similar embeddings.
Output: `gemeo/artifacts/hgt_patient_encoder.pt`, the checkpoint for an
`HGTEncoder` (built by `build_model`) that exposes
`encode_patient(snapshot_dict) -> np.ndarray`.
This is a scaffold — fill in the embedding sources and run on a GPU.
"""
from __future__ import annotations
import os
import logging
from dataclasses import dataclass
# Dedicated named logger for this trainer, so its records can be filtered on their own.
logger = logging.getLogger("gemeo.train.hgt")
# Encoder checkpoint written by `train()`: <this file's dir>/../artifacts/hgt_patient_encoder.pt
CKPT = os.path.join(
    os.path.dirname(__file__), os.pardir, "artifacts", "hgt_patient_encoder.pt"
)
@dataclass
class HGTConfig:
    """Hyper-parameters for the HGT patient encoder and its training run.

    Every field has a default, so ``HGTConfig()`` is a complete configuration;
    ``train()`` overrides only ``epochs``.
    """

    # Hidden width of the per-node-type input projections and of every HGTConv layer.
    embed_dim: int = 256
    # Attention heads per HGTConv layer.
    n_heads: int = 8
    # Number of stacked HGTConv layers.
    n_layers: int = 3
    # Dropout probability applied after each HGTConv layer.
    dropout: float = 0.2
    # Width of the encoder's output projection and of the vector encode_patient returns.
    output_dim: int = 3072  # match raras-app fused embedding dim
    # AdamW learning rate.
    lr: float = 1e-4
    # Default epoch count; train() uses it when called without an override.
    epochs: int = 100
    # NOTE(review): not read by any visible code yet — presumably the loader batch size.
    batch_size: int = 1024
    # Loss weights for the two training heads described in the module docstring.
    # NOTE(review): not consumed yet (the training loop is a TODO); confirm they are
    # meant to sum to 1.
    contrastive_weight: float = 0.3
    link_pred_weight: float = 0.7
def build_model(cfg: HGTConfig, metadata, in_dim: int | dict[str, int] = 64):
    """Build an untrained HGT encoder for a heterogeneous graph.

    Args:
        cfg: Model hyper-parameters (``embed_dim``, ``n_heads``, ``n_layers``,
            ``dropout``, ``output_dim``).
        metadata: PyG ``(node_types, edge_types)`` pair, as produced by
            ``HeteroData.metadata()``; ``metadata[0]`` lists the node types.
        in_dim: Raw feature width of the node types. A single ``int`` is shared
            by every node type (default ``64``, the previously hard-coded
            width); a ``{node_type: width}`` mapping supports node types whose
            feature vectors differ in size.

    Returns:
        An ``HGTEncoder`` (``torch.nn.Module``) instance. The class is local to
        this function, so the instance cannot be pickled with
        ``torch.save(model)``; persist ``model.state_dict()`` instead.

    Raises:
        KeyError: if ``in_dim`` is a mapping that lacks one of the node types.
    """
    import torch
    import torch.nn as nn
    from torch_geometric.nn import HGTConv

    node_types = metadata[0]

    def feat_dim(node_type):
        # Per-type width when a mapping is given, otherwise the shared width.
        return in_dim[node_type] if isinstance(in_dim, dict) else in_dim

    class HGTEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            # One input projection per node type into the shared hidden width.
            self.lin = nn.ModuleDict(
                {nt: nn.Linear(feat_dim(nt), cfg.embed_dim) for nt in node_types}
            )
            self.convs = nn.ModuleList(
                [
                    HGTConv(cfg.embed_dim, cfg.embed_dim, metadata, heads=cfg.n_heads)
                    for _ in range(cfg.n_layers)
                ]
            )
            # Projection to the fused-embedding width. Not applied in forward();
            # per encode_patient's docstring it is meant for the pooled patient vector.
            self.proj = nn.Linear(cfg.embed_dim, cfg.output_dim)
            self.dropout = nn.Dropout(cfg.dropout)

        def forward(self, x_dict, edge_index_dict):
            """Return a dict of per-node-type embeddings of width ``cfg.embed_dim``."""
            x_dict = {nt: torch.relu(self.lin[nt](x)) for nt, x in x_dict.items()}
            for conv in self.convs:
                out = conv(x_dict, edge_index_dict)
                # HGTConv reports None for a node type that receives no messages.
                # Keep that type's previous embedding instead of crashing on
                # relu(None) and losing the type for the following layers.
                x_dict = {
                    nt: x if out.get(nt) is None else self.dropout(torch.relu(out[nt]))
                    for nt, x in x_dict.items()
                }
            return x_dict

        def encode_patient(self, snapshot_dict):
            """Inference path used by `gemeo.encoder.encode_hgt`.

            Intended behaviour: build a 1-hop subgraph around the patient's
            HPO+Gene+Disease nodes, run the encoder, mean-pool, project to
            3072-dim, L2-normalize.

            Currently a stub: ignores ``snapshot_dict`` and returns an all-zero
            float32 vector of length ``cfg.output_dim``. Callers that normalize
            or compute cosine similarity on it must handle the zero norm.
            """
            import numpy as np

            # Stub: returns zero vector. Real implementation in PR-2.
            return np.zeros(cfg.output_dim, dtype=np.float32)

    return HGTEncoder()
def train(epochs: int = None):
"""Main training entry. Run as `python -m gemeo.train.hgt`."""
import torch
cfg = HGTConfig(epochs=epochs or HGTConfig.epochs)
primekg_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "primekg", "primekg.pt")
if not os.path.exists(primekg_path):
logger.warning(f"PrimeKG not found at {primekg_path}; run gemeo.train.primekg first")
return
blob = torch.load(primekg_path, weights_only=False)
data = blob["data"]
metadata = data.metadata()
model = build_model(cfg, metadata)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
# TODO: build train loader from feedback.jsonl + Neo4j patient labels
# TODO: run loop with link-pred + contrastive losses
logger.info("Training loop scaffold β€” fill in dataset loader and losses")
os.makedirs(os.path.dirname(CKPT), exist_ok=True)
torch.save(model, CKPT)
logger.info(f"saved {CKPT}")
if __name__ == "__main__":
    import argparse

    # Expose train()'s `epochs` knob on the command line. Without flags the
    # behaviour is unchanged: train() falls back to HGTConfig.epochs.
    parser = argparse.ArgumentParser(description="Train the GEMEO HGT patient encoder.")
    parser.add_argument(
        "--epochs",
        type=int,
        default=None,
        help="number of training epochs (default: HGTConfig.epochs)",
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    train(epochs=args.epochs)