"""HGT — Heterogeneous Graph Transformer for patient embedding (Phase 2). Following Hu et al. ICML 2020 + clinical adaptations from KG sparsification papers (arXiv 2510.08655). Trained with two heads: 1. **Disease link prediction** — predict Patient → Disease edges (positives: confirmed diagnoses; negatives: random + hard negatives from feedback ledger). 2. **Patient-Patient contrastive** (SimCLR style) — patients with confirmed same disease should have similar embeddings. Output: `gemeo/artifacts/hgt_patient_encoder.pt` containing a `HGTEncoder` with method `encode_patient(snapshot_dict) -> np.ndarray`. This is a scaffold — fill in the embedding sources and run on GPU. """ from __future__ import annotations import os import logging from dataclasses import dataclass logger = logging.getLogger("gemeo.train.hgt") CKPT = os.path.join(os.path.dirname(__file__), "..", "artifacts", "hgt_patient_encoder.pt") @dataclass class HGTConfig: embed_dim: int = 256 n_heads: int = 8 n_layers: int = 3 dropout: float = 0.2 output_dim: int = 3072 # match raras-app fused embedding dim lr: float = 1e-4 epochs: int = 100 batch_size: int = 1024 contrastive_weight: float = 0.3 link_pred_weight: float = 0.7 def build_model(cfg: HGTConfig, metadata): import torch import torch.nn as nn from torch_geometric.nn import HGTConv class HGTEncoder(nn.Module): def __init__(self): super().__init__() self.lin = nn.ModuleDict({nt: nn.Linear(64, cfg.embed_dim) for nt in metadata[0]}) self.convs = nn.ModuleList([ HGTConv(cfg.embed_dim, cfg.embed_dim, metadata, heads=cfg.n_heads) for _ in range(cfg.n_layers) ]) self.proj = nn.Linear(cfg.embed_dim, cfg.output_dim) self.dropout = nn.Dropout(cfg.dropout) def forward(self, x_dict, edge_index_dict): x_dict = {nt: torch.relu(self.lin[nt](x)) for nt, x in x_dict.items()} for conv in self.convs: x_dict = conv(x_dict, edge_index_dict) x_dict = {k: self.dropout(torch.relu(v)) for k, v in x_dict.items()} return x_dict def encode_patient(self, snapshot_dict): """Inference path used by `gemeo.encoder.encode_hgt`. Build a 1-hop subgraph around the patient's HPO+Gene+Disease nodes, run the encoder, mean-pool, project to 3072-dim, L2-normalize. """ import numpy as np # Stub: returns zero vector. Real implementation in PR-2. return np.zeros(cfg.output_dim, dtype=np.float32) return HGTEncoder() def train(epochs: int = None): """Main training entry. Run as `python -m gemeo.train.hgt`.""" import torch cfg = HGTConfig(epochs=epochs or HGTConfig.epochs) primekg_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "primekg", "primekg.pt") if not os.path.exists(primekg_path): logger.warning(f"PrimeKG not found at {primekg_path}; run gemeo.train.primekg first") return blob = torch.load(primekg_path, weights_only=False) data = blob["data"] metadata = data.metadata() model = build_model(cfg, metadata) opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr) # TODO: build train loader from feedback.jsonl + Neo4j patient labels # TODO: run loop with link-pred + contrastive losses logger.info("Training loop scaffold — fill in dataset loader and losses") os.makedirs(os.path.dirname(CKPT), exist_ok=True) torch.save(model, CKPT) logger.info(f"saved {CKPT}") if __name__ == "__main__": logging.basicConfig(level=logging.INFO) train()