"""HGT — Heterogeneous Graph Transformer for patient embedding (Phase 2).

Follows Hu et al., WWW 2020, plus clinical adaptations from KG sparsification
papers (arXiv 2510.08655). Trained with two heads:

1. **Disease link prediction** — predict Patient → Disease edges (positives:
   confirmed diagnoses; negatives: random + hard negatives from the feedback
   ledger).
2. **Patient-Patient contrastive** (SimCLR style) — patients with the same
   confirmed disease should have similar embeddings.

Output: `gemeo/artifacts/hgt_patient_encoder.pt` containing an `HGTEncoder`
with method `encode_patient(snapshot_dict) -> np.ndarray`.

This is a scaffold — fill in the embedding sources and run on GPU.
"""
| from __future__ import annotations |
| import os |
| import logging |
| from dataclasses import dataclass |
|
|
logger = logging.getLogger("gemeo.train.hgt")

# Checkpoint path written by train(): `../artifacts/hgt_patient_encoder.pt`,
# resolved against the directory that contains this module.
CKPT = os.path.join(
    os.path.dirname(__file__),
    "..",
    "artifacts",
    "hgt_patient_encoder.pt",
)
|
|
|
|
@dataclass
class HGTConfig:
    """Hyper-parameters for the HGT patient encoder and its training run.

    Field order defines the positional constructor signature; do not reorder.
    """

    # -- architecture --
    embed_dim: int = 256  # hidden width after the per-node-type input projection
    n_heads: int = 8  # attention heads per HGTConv layer
    n_layers: int = 3  # number of stacked HGTConv layers
    dropout: float = 0.2  # dropout probability applied after each conv layer
    output_dim: int = 3_072  # width of the final projection / patient embedding
    # -- optimisation --
    lr: float = 1.0e-4  # AdamW learning rate
    epochs: int = 100
    batch_size: int = 1_024
    # -- loss mixing between the two training heads --
    contrastive_weight: float = 0.3  # patient-patient contrastive head
    link_pred_weight: float = 0.7  # patient -> disease link-prediction head
|
|
|
|
def build_model(cfg: HGTConfig, metadata, in_dim=64):
    """Build the HGT patient encoder for a heterogeneous graph.

    torch / torch_geometric are imported inside this function, so importing the
    module itself does not require them.

    Args:
        cfg: hyper-parameters; ``embed_dim``, ``n_heads``, ``n_layers``,
            ``dropout`` and ``output_dim`` are read here.
        metadata: ``(node_types, edge_types)`` pair (e.g. from
            ``HeteroData.metadata()``); ``metadata[0]`` is iterated as the node
            types and the whole pair is handed to every ``HGTConv``.
        in_dim: width of the raw input node features. Either a single int shared
            by every node type (default 64) or a mapping ``node_type -> width``
            that must contain every node type in ``metadata[0]``.

    Returns:
        A freshly initialised ``HGTEncoder`` (an ``nn.Module``).

    Raises:
        KeyError: if ``in_dim`` is a mapping that lacks one of the node types.
    """
    from collections.abc import Mapping

    import torch
    import torch.nn as nn
    from torch_geometric.nn import HGTConv

    node_types = list(metadata[0])
    if isinstance(in_dim, Mapping):
        in_dims = {nt: in_dim[nt] for nt in node_types}
    else:
        in_dims = dict.fromkeys(node_types, in_dim)

    class HGTEncoder(nn.Module):
        """Per-node-type input projection followed by ``cfg.n_layers`` HGTConv layers."""

        def __init__(self):
            super().__init__()
            # Attribute names (lin / convs / proj / dropout) define the state_dict
            # keys; keep them stable so existing checkpoints stay loadable.
            self.lin = nn.ModuleDict(
                {nt: nn.Linear(in_dims[nt], cfg.embed_dim) for nt in node_types}
            )
            self.convs = nn.ModuleList(
                [
                    HGTConv(cfg.embed_dim, cfg.embed_dim, metadata, heads=cfg.n_heads)
                    for _ in range(cfg.n_layers)
                ]
            )
            # Projection to cfg.output_dim; not applied in forward() yet
            # (reserved for the encode_patient pipeline).
            self.proj = nn.Linear(cfg.embed_dim, cfg.output_dim)
            self.dropout = nn.Dropout(cfg.dropout)

        def forward(self, x_dict, edge_index_dict):
            """Return a dict of per-node-type embeddings of width ``cfg.embed_dim``.

            ``x_dict`` maps node type -> feature tensor whose last dimension is
            that type's input width (presumably (num_nodes, in_dim) -- confirm);
            ``edge_index_dict`` maps edge type -> edge index, as HGTConv expects.
            """
            x_dict = {nt: torch.relu(self.lin[nt](x)) for nt, x in x_dict.items()}
            for conv in self.convs:
                out_dict = conv(x_dict, edge_index_dict)
                # HGTConv may return None for, or omit, node types that receive no
                # messages (this differs between PyG versions). Keep such types'
                # previous embedding so the next layer still sees every node type,
                # instead of crashing on torch.relu(None) or a missing key.
                x_dict = {
                    nt: x if out_dict.get(nt) is None else self.dropout(torch.relu(out_dict[nt]))
                    for nt, x in x_dict.items()
                }
            return x_dict

        def encode_patient(self, snapshot_dict):
            """Inference path used by `gemeo.encoder.encode_hgt`.

            Intended pipeline: build a 1-hop subgraph around the patient's
            HPO+Gene+Disease nodes, run the encoder, mean-pool, project to
            ``cfg.output_dim`` (3072 by default), L2-normalize.

            NOTE: not implemented yet -- this currently ignores ``snapshot_dict``
            and returns an all-zero float32 vector of length ``cfg.output_dim``.
            """
            import numpy as np

            return np.zeros(cfg.output_dim, dtype=np.float32)

        def __reduce__(self):
            # HGTEncoder is local to build_model, so pickle (and therefore
            # torch.save(model)) cannot locate the class by qualified name and
            # fails with "Can't pickle local object". Serialize instead as
            # "call build_model with the same arguments, then restore this
            # instance's state"; nn.Module.__setstate__ merges the saved __dict__
            # (parameters, submodules, ...) into the rebuilt module on load.
            # Loading still needs this module importable and, on recent torch,
            # torch.load(..., weights_only=False).
            return build_model, (cfg, metadata, in_dim), self.__dict__.copy()

    return HGTEncoder()
|
|
|
|
def train(epochs: int | None = None):
    """Main training entry. Run as `python -m gemeo.train.hgt`.

    Loads the PrimeKG graph from ``data/primekg/primekg.pt`` two directories
    above this file, builds the HGT encoder with ``build_model`` and saves the
    whole model object to ``CKPT``. The optimisation loop is still a scaffold,
    so the saved model currently holds its freshly initialised weights.

    Args:
        epochs: number of training epochs. ``None`` (the default) uses
            ``HGTConfig.epochs``; any explicit value, including 0, is kept.

    Returns:
        None. If the PrimeKG file is missing, logs a warning and returns
        without building or saving anything.
    """
    import torch

    # `is None` rather than `or`: an explicit epochs=0 must not become the default.
    cfg = HGTConfig() if epochs is None else HGTConfig(epochs=epochs)

    primekg_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "primekg", "primekg.pt")
    if not os.path.exists(primekg_path):
        logger.warning("PrimeKG not found at %s; run gemeo.train.primekg first", primekg_path)
        return

    # NOTE(security): weights_only=False unpickles arbitrary objects (needed for
    # the stored graph object) and can execute code -- only load files produced
    # by this pipeline, never untrusted downloads.
    blob = torch.load(primekg_path, weights_only=False)
    data = blob["data"]
    metadata = data.metadata()

    model = build_model(cfg, metadata)
    # Optimizer for the (not yet written) training loop below.
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    # TODO: dataset loader; link-prediction + contrastive losses weighted by
    # cfg.link_pred_weight / cfg.contrastive_weight; cfg.epochs steps with `opt`.
    logger.info("Training loop scaffold β fill in dataset loader and losses")

    os.makedirs(os.path.dirname(CKPT), exist_ok=True)
    torch.save(model, CKPT)
    logger.info("saved %s", CKPT)
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Under `python -m gemeo.train.hgt` this module runs as `__main__`, so
    # anything torch.save() pickles by reference (functions/classes defined in
    # this file) would be recorded as `__main__.<name>` and could not be resolved
    # when the checkpoint is loaded from another entry point. Run the copy
    # imported under its package name instead; fall back to this module when the
    # package is not importable (e.g. executed as a plain script path).
    try:
        from gemeo.train.hgt import train as _train
    except ImportError:
        _train = train
    _train()
|
|