"""PatientEncoder — turn a ClinicalSnapshot into a vector.
Strategy:
1. **Bootstrap (works TODAY)**: weighted mean of pre-computed embeddings
from raras-app graph-ml — disease (×2), phenotype (×1), gene (×1).
This matches what raras-app does in `generate-patient-embeddings.mjs`,
so embeddings live in the same space as `Patient.embedding` in Neo4j
and `/api/graph/similar-patients`.
2. **HGT (Phase 2, gemeo/train/hgt.py)**: trained heterogeneous graph
transformer that produces patient embeddings via attention over
the patient's HPO+Gene+Lab subgraph. Replaces step (1) when
`gemeo/artifacts/hgt_patient_encoder.pt` exists.
The fallback chain is deterministic — `encode()` always returns a vector
(even if all zeros, signalled by the quality tag `'bootstrap_empty'`).
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from . import bridge
# Module-level logger shared by the encoder functions in this file.
logger = logging.getLogger("gemeo.encoder")
# Default embedding width; used as the default `dim` argument of the
# encoders defined in this module.
DEFAULT_DIM = 3072 # matches raras-app fused embedding & Neo4j vector index
# Location of the trained HGT checkpoint. The GEMEO_HGT_CKPT environment
# variable overrides it; otherwise it points at
# artifacts/hgt_patient_encoder.pt next to this module.
HGT_CKPT = os.environ.get(
    "GEMEO_HGT_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "hgt_patient_encoder.pt"),
)
def _l2_normalize(vec):
    """Scale ``vec`` to unit Euclidean length.

    A vector whose norm is below 1e-9 is handed back untouched (the very same
    object), so an all-zero input never causes a division by zero.
    """
    import numpy as np

    magnitude = np.linalg.norm(vec)
    return vec if magnitude < 1e-9 else vec / magnitude
def encode_bootstrap(
    phenotypes: list,
    diseases: list,
    genes: list,
    *,
    dim: int = DEFAULT_DIM,
    weight_disease: float = 2.0,
    weight_phenotype: float = 1.0,
    weight_gene: float = 1.0,
):
    """Aggregate raras-app fused embeddings into a single patient vector.

    Each identifier is resolved through ``bridge``; the embeddings that are
    found are combined as a weighted mean (diseases, then phenotypes, then
    genes) and the result is L2-normalised. Per the module notes this mirrors
    raras-app's bootstrap, so the vector is meant to live in the same space as
    the Neo4j ``Patient.embedding``.

    Args:
        phenotypes: HPO ids ["HP:0001250", ...] or dicts keyed by
            "hpo_id" / "hpoId" / "id".
        diseases: ORPHA codes ["79253", ...] or dicts keyed by
            "orpha" / "orpha_code" / "orphaCode" / "code".
        genes: HGNC symbols ["GBA", ...] or dicts keyed by
            "symbol" / "gene" / "name". Symbols are upper-cased before lookup.
        dim: expected embedding width; lookups returning any other shape
            count as misses.
        weight_disease: weight applied to each disease embedding.
        weight_phenotype: weight applied to each phenotype embedding.
        weight_gene: weight applied to each gene embedding.

    Bare identifiers may be ``str`` or ``int`` (every identifier is passed
    through ``str()`` before lookup). Entries of any other type, and dicts
    without a truthy recognised key, are ignored and do not count as inputs.

    Returns:
        (vector: np.ndarray (dim,), quality: str)
        quality ∈ {"empty", "partial", "full"}: "empty" when the accumulated
        weight is zero (the vector is then all zeros), "full" when every
        extracted identifier resolved, "partial" otherwise.
    """
    import numpy as np

    def _ids(coll, key_options):
        """Pull identifiers out of bare values or the first truthy dict key."""
        out = []
        for item in coll or []:
            if isinstance(item, str):
                out.append(item)
            elif isinstance(item, int) and not isinstance(item, bool):
                # Numeric codes (e.g. ORPHA) arrive as ints from JSON; they
                # are stringified below just like ints found inside dicts.
                out.append(item)
            elif isinstance(item, dict):
                found = next((item[k] for k in key_options if item.get(k)), None)
                if found is not None:
                    out.append(found)
        return out

    # (identifiers, lookup function, weight, key normaliser) per source.
    # Order matters: it fixes the floating-point accumulation order.
    sources = (
        (
            _ids(diseases, ["orpha", "orpha_code", "orphaCode", "code"]),
            bridge.lookup_disease_embedding,
            weight_disease,
            str,
        ),
        (
            _ids(phenotypes, ["hpo_id", "hpoId", "id"]),
            bridge.lookup_phenotype_embedding,
            weight_phenotype,
            str,
        ),
        (
            _ids(genes, ["symbol", "gene", "name"]),
            bridge.lookup_gene_embedding,
            weight_gene,
            lambda sym: str(sym).upper(),
        ),
    )

    accum = np.zeros(dim, dtype=np.float32)
    total_weight = 0.0
    n_input = 0
    n_hit = 0
    for ids, lookup, weight, to_key in sources:
        n_input += len(ids)
        for ident in ids:
            raw = lookup(to_key(ident))
            if raw is None:
                continue
            vec = np.asarray(raw, dtype=np.float32)
            # Require exactly (dim,): anything else cannot be added in place.
            if vec.shape != (dim,):
                continue
            accum += weight * vec
            total_weight += weight
            n_hit += 1

    if total_weight == 0:
        return np.zeros(dim, dtype=np.float32), "empty"

    avg = _l2_normalize(accum / total_weight)
    quality = "full" if n_hit == n_input else "partial"
    return avg, quality
# ─── HGT slot (Phase 2) ────────────────────────────────────────────────────
# Cached HGT model; stays None until a load succeeds.
_HGT_MODEL = None
# Modification time of the checkpoint whose load last failed (None = no failure).
_HGT_FAILED_MTIME = None


def _try_load_hgt():
    """Return the cached HGT patient encoder, loading it on first use.

    Returns:
        The object stored in the checkpoint at ``HGT_CKPT``, or ``None`` when
        the checkpoint is missing or cannot be loaded.

    A failed load is remembered together with the checkpoint's modification
    time, so a broken file is not re-read (and the warning not re-logged) on
    every call. Replacing the file changes its mtime and triggers one new
    attempt. A missing checkpoint is re-checked on every call, so a file
    dropped in later is picked up.
    """
    global _HGT_MODEL, _HGT_FAILED_MTIME
    if _HGT_MODEL is not None:
        return _HGT_MODEL
    try:
        ckpt_mtime = os.path.getmtime(HGT_CKPT)
    except (OSError, ValueError):
        # No usable checkpoint path (yet).
        return None
    if ckpt_mtime == _HGT_FAILED_MTIME:
        # Same file that already failed to load; do not retry.
        return None
    try:
        import torch

        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # HGT_CKPT can be redirected through the GEMEO_HGT_CKPT environment
        # variable. Only point it at checkpoints from a trusted source.
        model = torch.load(HGT_CKPT, map_location="cpu", weights_only=False)
    except Exception as e:
        _HGT_FAILED_MTIME = ckpt_mtime
        logger.warning(f"HGT checkpoint exists but failed to load: {e}")
        return None
    _HGT_MODEL = model
    logger.info(f"Loaded HGT patient encoder from {HGT_CKPT}")
    return _HGT_MODEL
def encode_hgt(snapshot_dict: dict, dim: int = DEFAULT_DIM):
    """Encode a snapshot with the trained HGT model, if one is available.

    Args:
        snapshot_dict: structured patient dict handed unchanged to the
            model's ``encode_patient`` method.
        dim: accepted for signature parity with the other encoders; the code
            here does not use it.

    Returns:
        Whatever ``encode_patient`` returns, or ``None`` when no model could
        be loaded, the model has no ``encode_patient`` attribute, or the call
        raised (the error is logged).
    """
    hgt = _try_load_hgt()
    if hgt is None:
        return None
    try:
        # The forward contract is owned by gemeo/train/hgt.py; a loaded object
        # without the expected entry point is treated as "no model".
        if not hasattr(hgt, "encode_patient"):
            return None
        return hgt.encode_patient(snapshot_dict)
    except Exception as err:
        logger.error(f"HGT encode failed, falling back to bootstrap: {err}")
        return None
def encode(snapshot_dict: dict, *, dim: int = DEFAULT_DIM):
    """Encode a patient snapshot, preferring HGT over the bootstrap mean.

    snapshot_dict expected keys:
        phenotypes: [{hpo_id, name, ...}, ...]
        diseases: [{orpha, ...}, ...] (optional — confirmed/probable dx)
        genes: [{symbol, ...}, ...]

    Returns:
        (vector, tag) where tag is "hgt" when the HGT encoder produced the
        vector, otherwise "bootstrap_" followed by the bootstrap quality.
    """
    hgt_vec = encode_hgt(snapshot_dict, dim=dim)
    if hgt_vec is None:
        # No HGT result: aggregate the pre-computed embeddings instead.
        # Missing keys default to empty lists.
        inputs = {
            field: snapshot_dict.get(field, [])
            for field in ("phenotypes", "diseases", "genes")
        }
        bootstrap_vec, bootstrap_quality = encode_bootstrap(**inputs, dim=dim)
        return bootstrap_vec, f"bootstrap_{bootstrap_quality}"
    return hgt_vec, "hgt"
def encode_patient_space(space) -> tuple:
    """Encode a `PatientSpace`-like object by delegating to `encode`.

    Uses the object's current snapshot when `get_current_snapshot` exists and
    returns one. Otherwise only the ORPHA codes found on the object's
    `_hypotheses` values are encoded, with phenotypes and genes left empty.

    Returns:
        The (vector, tag) pair produced by `encode`.
    """
    snapshot = None
    if hasattr(space, "get_current_snapshot"):
        snapshot = space.get_current_snapshot()

    if snapshot is not None:
        return encode({
            "phenotypes": snapshot.phenotypes,
            "diseases": snapshot.diagnoses,
            "genes": snapshot.genes,
        })

    # No snapshot available: fall back to the hypotheses carrying an ORPHA code.
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    candidate_dx = [
        {"orpha": hyp.orpha_code}
        for hyp in hypotheses.values()
        if hasattr(hyp, "orpha_code") and hyp.orpha_code
    ]
    return encode({
        "phenotypes": [],
        "diseases": candidate_dx,
        "genes": [],
    })
|