gemeo-twin-stack / src /gemeo /encoder.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""PatientEncoder — turn a ClinicalSnapshot into a vector.
Strategy:
1. **Bootstrap (works TODAY)**: weighted mean of pre-computed embeddings
from raras-app graph-ml — disease (×2), phenotype (×1), gene (×1).
This matches what raras-app does in `generate-patient-embeddings.mjs`,
so embeddings live in the same space as `Patient.embedding` in Neo4j
and `/api/graph/similar-patients`.
2. **HGT (Phase 2, gemeo/train/hgt.py)**: trained heterogeneous graph
transformer that produces patient embeddings via attention over
the patient's HPO+Gene+Lab subgraph. Replaces step (1) when
`gemeo/artifacts/hgt_patient_encoder.pt` exists.
The fallback chain is deterministic — `encode()` always returns a vector
(even if zeros, signalled by `quality='empty'`).
"""
from __future__ import annotations
import logging
import os
from typing import Optional
from . import bridge
# Module-wide logger; the channel name is spelled out rather than taken
# from __name__.
logger = logging.getLogger("gemeo.encoder")

# Default embedding width: matches raras-app fused embedding & Neo4j vector index.
DEFAULT_DIM = 3072

# Path of the trained HGT patient-encoder checkpoint. The GEMEO_HGT_CKPT
# environment variable takes precedence; otherwise the file is looked up in
# the "artifacts" directory located next to this module.
HGT_CKPT = os.getenv(
    "GEMEO_HGT_CKPT",
    os.path.join(os.path.dirname(__file__), "artifacts", "hgt_patient_encoder.pt"),
)
def _l2_normalize(vec):
import numpy as np
n = np.linalg.norm(vec)
if n < 1e-9:
return vec
return vec / n
def encode_bootstrap(
    phenotypes: list,
    diseases: list,
    genes: list,
    *,
    dim: int = DEFAULT_DIM,
    weight_disease: float = 2.0,
    weight_phenotype: float = 1.0,
    weight_gene: float = 1.0,
):
    """Aggregate raras-app fused embeddings — same space as Neo4j Patient.embedding.

    Every id that resolves (via ``bridge``) to an embedding of length ``dim``
    contributes to a weighted mean, which is then L2-normalised.

    Args:
        phenotypes: list of HPO ids ["HP:0001250", ...] or dicts {"hpo_id": ...}
        diseases: list of ORPHA codes ["79253", ...] or dicts {"orpha": ...}
        genes: list of HGNC symbols ["GBA", ...] or dicts {"symbol": ...}
        dim: expected embedding length; lookups returning another length
            are treated as misses.
        weight_disease: weight applied to each resolved disease embedding.
        weight_phenotype: weight applied to each resolved phenotype embedding.
        weight_gene: weight applied to each resolved gene embedding.

    Returns:
        (vector: np.ndarray (dim,), quality: str)
        quality ∈ {"empty", "partial", "full"}

        "empty" (with an all-zero vector) is returned when the accumulated
        weight is zero; "full" when every extracted id resolved; "partial"
        otherwise. Entries from which no id can be extracted (neither a
        string nor a dict carrying one of the recognised keys) are ignored
        and do not influence ``quality``.
    """
    import numpy as np

    def _extract_ids(items, candidate_keys):
        # Strings are taken verbatim; for dicts the first candidate key
        # holding a truthy value wins. Anything else is skipped.
        found = []
        for item in items or []:
            if isinstance(item, str):
                found.append(item)
            elif isinstance(item, dict):
                value = next(
                    (item[key] for key in candidate_keys if item.get(key)),
                    None,
                )
                if value is not None:
                    found.append(value)
        return found

    hpo_ids = _extract_ids(phenotypes, ["hpo_id", "hpoId", "id"])
    orpha_ids = _extract_ids(diseases, ["orpha", "orpha_code", "orphaCode", "code"])
    gene_ids = _extract_ids(genes, ["symbol", "gene", "name"])

    # (ids, weight, lookup) per entity kind, processed in the order
    # disease -> phenotype -> gene. The lambdas defer touching `bridge`
    # until an id actually has to be resolved.
    sources = (
        (orpha_ids, weight_disease,
         lambda key: bridge.lookup_disease_embedding(str(key))),
        (hpo_ids, weight_phenotype,
         lambda key: bridge.lookup_phenotype_embedding(str(key))),
        # Gene symbols are upper-cased before the lookup.
        (gene_ids, weight_gene,
         lambda key: bridge.lookup_gene_embedding(str(key).upper())),
    )

    total = np.zeros(dim, dtype=np.float32)
    weight_sum = 0
    resolved = 0
    for keys, weight, lookup in sources:
        for key in keys:
            emb = lookup(key)
            # A missing embedding or one of the wrong length counts as a miss.
            if emb is None or emb.shape[0] != dim:
                continue
            total += weight * emb.astype(np.float32)
            weight_sum += weight
            resolved += 1

    if weight_sum == 0:
        return np.zeros(dim, dtype=np.float32), "empty"

    mean = _l2_normalize(total / weight_sum)
    requested = len(hpo_ids) + len(orpha_ids) + len(gene_ids)
    return mean, ("full" if resolved == requested else "partial")
# ─── HGT slot (Phase 2) ────────────────────────────────────────────────────
# Cached HGT model object; stays None until a checkpoint loads successfully.
_HGT_MODEL = None
# (mtime_ns, size) of the checkpoint file whose last load attempt failed, so
# that a broken file is not re-read on every call.
_HGT_FAILED_STAMP = None


def _try_load_hgt():
    """Return the HGT patient encoder, loading it on first use.

    Returns:
        The object stored in the checkpoint at ``HGT_CKPT`` (cached in
        ``_HGT_MODEL`` after the first successful load), or ``None`` when
        the checkpoint is missing or cannot be loaded.

    A failed load is remembered together with the file's modification time
    and size: later calls return ``None`` immediately instead of re-reading
    the file and re-logging the warning, and a new attempt is made only once
    the file on disk changes. A checkpoint that appears later is still
    picked up, because the path is re-checked on every call.
    """
    global _HGT_MODEL, _HGT_FAILED_STAMP
    if _HGT_MODEL is not None:
        return _HGT_MODEL

    try:
        info = os.stat(HGT_CKPT)
    except (OSError, ValueError):
        # No usable checkpoint at that path.
        return None
    stamp = (info.st_mtime_ns, info.st_size)
    if stamp == _HGT_FAILED_STAMP:
        # Same file that already failed to load; do not retry.
        return None

    try:
        import torch

        # SECURITY: weights_only=False makes torch.load unpickle arbitrary
        # Python objects, which can execute code. HGT_CKPT (overridable via
        # GEMEO_HGT_CKPT) must therefore only ever point at a trusted file.
        _HGT_MODEL = torch.load(HGT_CKPT, map_location="cpu", weights_only=False)
        logger.info("Loaded HGT patient encoder from %s", HGT_CKPT)
        return _HGT_MODEL
    except Exception as e:
        # Best-effort by design: any failure (torch missing, corrupt file,
        # ...) leaves the caller without an HGT model.
        _HGT_FAILED_STAMP = stamp
        logger.warning("HGT checkpoint exists but failed to load: %s", e)
        return None
def encode_hgt(snapshot_dict: dict, dim: int = DEFAULT_DIM):
    """Encode a snapshot with the trained HGT patient encoder.

    Args:
        snapshot_dict: structured patient snapshot, forwarded unchanged to
            the model's ``encode_patient`` method.
        dim: accepted for signature parity with the other encoders; it is
            not consulted here.

    Returns:
        Whatever ``model.encode_patient(snapshot_dict)`` returns, or ``None``
        when no model is available, the loaded object has no
        ``encode_patient`` attribute, or the call raises (the error is
        logged).
    """
    hgt = _try_load_hgt()
    if hgt is None:
        return None

    try:
        # The actual forward signature is defined by gemeo/train/hgt.py.
        # We pass a structured dict with keys: phenotypes, genes, labs, diseases.
        if not hasattr(hgt, "encode_patient"):
            return None
        return hgt.encode_patient(snapshot_dict)
    except Exception as e:
        logger.error(f"HGT encode failed, falling back to bootstrap: {e}")
        return None
def encode(snapshot_dict: dict, *, dim: int = DEFAULT_DIM):
    """Top-level encoder: HGT first, bootstrap aggregation as fallback.

    snapshot_dict expected keys:
        phenotypes: [{hpo_id, name, ...}, ...]
        diseases: [{orpha, ...}, ...] (optional — confirmed/probable dx)
        genes: [{symbol, ...}, ...]

    Returns:
        ``(vector, source)`` where ``source`` is ``"hgt"`` when the HGT
        encoder produced a result, otherwise ``"bootstrap_<quality>"`` with
        the quality string reported by ``encode_bootstrap``.
    """
    hgt_vector = encode_hgt(snapshot_dict, dim=dim)
    if hgt_vector is None:
        # HGT unavailable or failed: aggregate pre-computed embeddings.
        # Keys absent from the snapshot default to empty lists.
        boot_vector, boot_quality = encode_bootstrap(
            phenotypes=snapshot_dict.get("phenotypes", []),
            diseases=snapshot_dict.get("diseases", []),
            genes=snapshot_dict.get("genes", []),
            dim=dim,
        )
        return boot_vector, f"bootstrap_{boot_quality}"
    return hgt_vector, "hgt"
def encode_patient_space(space) -> tuple:
    """Convenience: encode a `PatientSpace` object directly.

    When ``space`` exposes ``get_current_snapshot()`` and it returns a
    snapshot, that snapshot's ``phenotypes``, ``diagnoses`` and ``genes``
    are encoded. Otherwise the input is built from the ORPHA codes of the
    entries in ``space._hypotheses`` only, with no phenotypes or genes.

    Returns:
        The ``(vector, source)`` pair produced by ``encode``.
    """
    if hasattr(space, "get_current_snapshot"):
        snap = space.get_current_snapshot()
    else:
        snap = None

    if snap is not None:
        return encode({
            "phenotypes": snap.phenotypes,
            "diseases": snap.diagnoses,
            "genes": snap.genes,
        })

    # Fallback: no snapshots, build from hypotheses + recent events
    hypotheses = getattr(space, "_hypotheses", {}) or {}
    diseases = [
        {"orpha": hyp.orpha_code}
        for hyp in hypotheses.values()
        if hasattr(hyp, "orpha_code") and hyp.orpha_code
    ]
    return encode({
        "phenotypes": [],
        "diseases": diseases,
        "genes": [],
    })