"""PatientEncoder — turn a ClinicalSnapshot into a vector.

Strategy:
    1. **Bootstrap (works TODAY)**: weighted mean of pre-computed embeddings
       from raras-app graph-ml — disease (×2), phenotype (×1), gene (×1).
       This matches what raras-app does in `generate-patient-embeddings.mjs`,
       so embeddings live in the same space as `Patient.embedding` in Neo4j
       and `/api/graph/similar-patients`.

    2. **HGT (Phase 2, gemeo/train/hgt.py)**: trained heterogeneous graph
       transformer that produces patient embeddings via attention over
       the patient's HPO+Gene+Lab subgraph. Replaces step (1) when
       `gemeo/artifacts/hgt_patient_encoder.pt` (or the path given in the
       `GEMEO_HGT_CKPT` environment variable) exists, loads successfully,
       and provides `encode_patient`.

The fallback chain is deterministic — `encode()` always returns a vector
(even if all zeros, signalled by the quality label `'bootstrap_empty'`).
"""
| from __future__ import annotations |
| import logging |
| import os |
| from typing import Optional |
|
|
| from . import bridge |
|
|
# Module logger, registered under an explicit name rather than __name__.
logger = logging.getLogger("gemeo.encoder")

# Default embedding length used by every encoder in this module.
# NOTE(review): presumably the size of the raras-app fused embeddings — confirm.
DEFAULT_DIM = 3072

# Default HGT checkpoint location: artifacts/hgt_patient_encoder.pt next to this file.
_DEFAULT_HGT_CKPT = os.path.join(
    os.path.dirname(__file__), "artifacts", "hgt_patient_encoder.pt"
)
# GEMEO_HGT_CKPT, when present in the environment (even as ""), overrides the default.
HGT_CKPT = os.environ.get("GEMEO_HGT_CKPT", _DEFAULT_HGT_CKPT)
|
|
|
|
| def _l2_normalize(vec): |
| import numpy as np |
| n = np.linalg.norm(vec) |
| if n < 1e-9: |
| return vec |
| return vec / n |
|
|
|
|
def encode_bootstrap(
    phenotypes: list,
    diseases: list,
    genes: list,
    *,
    dim: int = DEFAULT_DIM,
    weight_disease: float = 2.0,
    weight_phenotype: float = 1.0,
    weight_gene: float = 1.0,
):
    """Aggregate raras-app fused embeddings — same space as Neo4j Patient.embedding.

    Each disease, phenotype and gene id is looked up through ``bridge``.
    Every vector found with shape exactly ``(dim,)`` is added with its
    modality weight. The sum is divided by the total weight and the result is
    L2-normalized.

    Args:
        phenotypes: HPO ids ["HP:0001250", ...], dicts {"hpo_id": ...}, or
            objects exposing ``hpo_id`` / ``hpoId`` / ``id`` as an attribute.
        diseases: ORPHA codes ["79253", ...] (plain ints are accepted too),
            dicts {"orpha": ...}, or objects exposing ``orpha`` /
            ``orpha_code`` / ``orphaCode`` / ``code``.
        genes: HGNC symbols ["GBA", ...], dicts {"symbol": ...}, or objects
            exposing ``symbol`` / ``gene`` / ``name``. Symbols are upper-cased
            before lookup.
        dim: required embedding length. Lookups of any other shape count as
            misses.
        weight_disease, weight_phenotype, weight_gene: per-modality weights.

    Returns:
        (vector: np.ndarray (dim,) float32, quality: str)
        quality ∈ {"empty", "partial", "full"}:
          * "empty"   — found embeddings carry zero total weight (typically
                        nothing was found); the vector is all zeros;
          * "full"    — every extracted id resolved to a usable embedding;
          * "partial" — some ids did not.
    """
    import numpy as np

    def _field(item, key):
        # Dicts are read by key, any other object by attribute. Callables
        # (e.g. a method sharing a key's name) are never treated as ids.
        if isinstance(item, dict):
            return item.get(key)
        value = getattr(item, key, None)
        return None if callable(value) else value

    def _extract_ids(items, keys):
        # One id per item: strings as-is, ints stringified, otherwise the
        # first truthy field among `keys`. Items yielding no id are skipped.
        out = []
        for item in items or []:
            if isinstance(item, str):
                out.append(item)
            elif isinstance(item, int) and not isinstance(item, bool):
                out.append(str(item))
            else:
                for key in keys:
                    value = _field(item, key)
                    if value:
                        out.append(value)
                        break
        return out

    orpha_ids = _extract_ids(diseases, ("orpha", "orpha_code", "orphaCode", "code"))
    hpo_ids = _extract_ids(phenotypes, ("hpo_id", "hpoId", "id"))
    gene_ids = _extract_ids(genes, ("symbol", "gene", "name"))

    # Accumulation order is fixed (disease, phenotype, gene) so float32 sums
    # are reproducible for the same inputs.
    plan = (
        ("disease", [str(x) for x in orpha_ids],
         bridge.lookup_disease_embedding, weight_disease),
        ("phenotype", [str(x) for x in hpo_ids],
         bridge.lookup_phenotype_embedding, weight_phenotype),
        ("gene", [str(x).upper() for x in gene_ids],
         bridge.lookup_gene_embedding, weight_gene),
    )

    accum = np.zeros(dim, dtype=np.float32)
    total_weight = 0.0
    hits = dict.fromkeys(("disease", "phenotype", "gene"), 0)
    misses = dict.fromkeys(("disease", "phenotype", "gene"), 0)

    for kind, keys, lookup, weight in plan:
        for key in keys:
            found = lookup(key)
            vec = None if found is None else np.asarray(found, dtype=np.float32)
            # Require exactly (dim,): anything else is either a different
            # embedding space or would fail to broadcast into `accum`.
            if vec is None or vec.shape != (dim,):
                misses[kind] += 1
                continue
            accum += weight * vec
            total_weight += weight
            hits[kind] += 1

    if any(misses.values()):
        logger.debug("bootstrap encode lookups: hits=%s misses=%s", hits, misses)

    if total_weight == 0:
        return np.zeros(dim, dtype=np.float32), "empty"

    avg = _l2_normalize(accum / total_weight)
    quality = "partial" if any(misses.values()) else "full"
    return avg, quality
|
|
|
|
| |
|
|
# Module-level cache of the loaded HGT checkpoint object (None until a load succeeds).
_HGT_MODEL = None
# (path, mtime_ns, size) of the last checkpoint that failed to load. Keeps a broken
# file from being re-unpickled and re-logged on every encode() call; any change to
# the file (or to HGT_CKPT) triggers a fresh attempt.
_HGT_FAILED_STAMP = None


def _try_load_hgt():
    """Return the HGT checkpoint object, loading and caching it on first success.

    Returns:
        The object produced by ``torch.load(HGT_CKPT)``, or None when the file
        does not exist or cannot be loaded (including when torch is missing).

    A missing file is re-checked on every call, so a checkpoint that appears
    later is picked up. A file that failed to load is retried only after its
    path, mtime or size changes.
    """
    global _HGT_MODEL, _HGT_FAILED_STAMP
    if _HGT_MODEL is not None:
        return _HGT_MODEL
    try:
        st = os.stat(HGT_CKPT)
    except OSError:
        return None
    stamp = (HGT_CKPT, st.st_mtime_ns, st.st_size)
    if stamp == _HGT_FAILED_STAMP:
        return None
    try:
        import torch

        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code from the file. Only point GEMEO_HGT_CKPT at trusted checkpoints.
        model = torch.load(HGT_CKPT, map_location="cpu", weights_only=False)
    except Exception as e:
        _HGT_FAILED_STAMP = stamp
        logger.warning("HGT checkpoint exists but failed to load: %s", e)
        return None
    _HGT_MODEL = model
    logger.info("Loaded HGT patient encoder from %s", HGT_CKPT)
    return _HGT_MODEL
|
|
|
|
def encode_hgt(snapshot_dict: dict, dim: int = DEFAULT_DIM):
    """Run the trained HGT patient encoder on *snapshot_dict*, if available.

    Args:
        snapshot_dict: passed unchanged to the loaded model's ``encode_patient``.
        dim: accepted but not currently used by this function.

    Returns:
        Whatever ``model.encode_patient(snapshot_dict)`` returns, or None when
        no checkpoint is loaded, the loaded object lacks ``encode_patient``,
        or that call raises. A raised error is logged so the caller can fall
        back to the bootstrap encoder.
    """
    model = _try_load_hgt()
    embedding = None
    if model is not None:
        try:
            if hasattr(model, "encode_patient"):
                embedding = model.encode_patient(snapshot_dict)
        except Exception as e:
            logger.error(f"HGT encode failed, falling back to bootstrap: {e}")
            embedding = None
    return embedding
|
|
|
|
def encode(snapshot_dict: dict, *, dim: int = DEFAULT_DIM):
    """Encode a clinical snapshot, preferring HGT and falling back to bootstrap.

    Keys read by the bootstrap path (each defaults to [] when absent):
        phenotypes: [{hpo_id, name, ...}, ...]
        diseases:   [{orpha, ...}, ...]   (optional — confirmed/probable dx)
        genes:      [{symbol, ...}, ...]

    Returns:
        (vector, label). label is "hgt" when the HGT encoder produced a
        result, otherwise "bootstrap_" followed by the bootstrap quality
        ("empty", "partial" or "full").
    """
    hgt_vector = encode_hgt(snapshot_dict, dim=dim)
    if hgt_vector is not None:
        return hgt_vector, "hgt"

    fields = {
        name: snapshot_dict.get(name, [])
        for name in ("phenotypes", "diseases", "genes")
    }
    vector, quality = encode_bootstrap(**fields, dim=dim)
    return vector, "bootstrap_" + quality
|
|
|
|
def encode_patient_space(space, *, dim: int = DEFAULT_DIM) -> tuple:
    """Convenience: encode a `PatientSpace` object directly.

    Uses ``space.get_current_snapshot()`` when that method exists and returns
    something other than None. Its ``phenotypes``, ``diagnoses`` and ``genes``
    attributes are forwarded to `encode`. Otherwise the ORPHA codes of the
    hypotheses in ``space._hypotheses`` (if any) are used as diseases, with no
    phenotypes or genes.

    Args:
        space: a PatientSpace-like object.
        dim: embedding dimension forwarded to `encode`. Defaults to
            DEFAULT_DIM, the value always used before this keyword existed.

    Returns:
        The ``(vector, label)`` tuple produced by `encode`.
    """
    snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None
    if snap is not None:
        payload = {
            "phenotypes": snap.phenotypes,
            "diseases": snap.diagnoses,
            "genes": snap.genes,
        }
    else:
        # No snapshot: fall back to the hypotheses' ORPHA codes only.
        diseases = []
        for hyp in (getattr(space, "_hypotheses", {}) or {}).values():
            code = getattr(hyp, "orpha_code", None)
            if code:
                diseases.append({"orpha": code})
        payload = {"phenotypes": [], "diseases": diseases, "genes": []}
    return encode(payload, dim=dim)
|
|