"""PatientEncoder — turn a ClinicalSnapshot into a vector. Strategy: 1. **Bootstrap (works TODAY)**: weighted mean of pre-computed embeddings from raras-app graph-ml — disease (×2), phenotype (×1), gene (×1). This matches what raras-app does in `generate-patient-embeddings.mjs`, so embeddings live in the same space as `Patient.embedding` in Neo4j and `/api/graph/similar-patients`. 2. **HGT (Phase 2, gemeo/train/hgt.py)**: trained heterogeneous graph transformer that produces patient embeddings via attention over the patient's HPO+Gene+Lab subgraph. Replaces step (1) when `gemeo/artifacts/hgt_patient_encoder.pt` exists. The fallback chain is deterministic — `encode()` always returns a vector (even if zeros, signalled by `quality='empty'`). """ from __future__ import annotations import logging import os from typing import Optional from . import bridge logger = logging.getLogger("gemeo.encoder") DEFAULT_DIM = 3072 # matches raras-app fused embedding & Neo4j vector index HGT_CKPT = os.environ.get( "GEMEO_HGT_CKPT", os.path.join(os.path.dirname(__file__), "artifacts", "hgt_patient_encoder.pt"), ) def _l2_normalize(vec): import numpy as np n = np.linalg.norm(vec) if n < 1e-9: return vec return vec / n def encode_bootstrap( phenotypes: list, diseases: list, genes: list, *, dim: int = DEFAULT_DIM, weight_disease: float = 2.0, weight_phenotype: float = 1.0, weight_gene: float = 1.0, ): """Aggregate raras-app fused embeddings — same space as Neo4j Patient.embedding. Args: phenotypes: list of HPO ids ["HP:0001250", ...] or dicts {"hpo_id": ...} diseases: list of ORPHA codes ["79253", ...] or dicts {"orpha": ...} genes: list of HGNC symbols ["GBA", ...] or dicts {"symbol": ...} Returns: (vector: np.ndarray (dim,), quality: str) quality ∈ {"empty", "partial", "full"} """ import numpy as np def _ids(coll, key_options): out = [] for c in coll or []: if isinstance(c, str): out.append(c) elif isinstance(c, dict): for k in key_options: if c.get(k): out.append(c[k]); break return out hpo_ids = _ids(phenotypes, ["hpo_id", "hpoId", "id"]) orpha_ids = _ids(diseases, ["orpha", "orpha_code", "orphaCode", "code"]) gene_ids = _ids(genes, ["symbol", "gene", "name"]) accum = np.zeros(dim, dtype=np.float32) n = 0 hits = {"disease": 0, "phenotype": 0, "gene": 0} misses = {"disease": 0, "phenotype": 0, "gene": 0} for orpha in orpha_ids: v = bridge.lookup_disease_embedding(str(orpha)) if v is not None and v.shape[0] == dim: accum += weight_disease * v.astype(np.float32) n += weight_disease hits["disease"] += 1 else: misses["disease"] += 1 for hpo in hpo_ids: v = bridge.lookup_phenotype_embedding(str(hpo)) if v is not None and v.shape[0] == dim: accum += weight_phenotype * v.astype(np.float32) n += weight_phenotype hits["phenotype"] += 1 else: misses["phenotype"] += 1 for sym in gene_ids: v = bridge.lookup_gene_embedding(str(sym).upper()) if v is not None and v.shape[0] == dim: accum += weight_gene * v.astype(np.float32) n += weight_gene hits["gene"] += 1 else: misses["gene"] += 1 if n == 0: return np.zeros(dim, dtype=np.float32), "empty" avg = accum / n avg = _l2_normalize(avg) n_input = len(hpo_ids) + len(orpha_ids) + len(gene_ids) n_hit = sum(hits.values()) quality = "full" if n_hit == n_input else "partial" return avg, quality # ─── HGT slot (Phase 2) ──────────────────────────────────────────────────── _HGT_MODEL = None def _try_load_hgt(): global _HGT_MODEL if _HGT_MODEL is not None: return _HGT_MODEL if not os.path.exists(HGT_CKPT): return None try: import torch _HGT_MODEL = torch.load(HGT_CKPT, map_location="cpu", weights_only=False) logger.info(f"Loaded HGT patient encoder from {HGT_CKPT}") return _HGT_MODEL except Exception as e: logger.warning(f"HGT checkpoint exists but failed to load: {e}") return None def encode_hgt(snapshot_dict: dict, dim: int = DEFAULT_DIM): """Run the trained HGT patient encoder. Returns None if model unavailable.""" model = _try_load_hgt() if model is None: return None try: # The actual forward signature is defined by gemeo/train/hgt.py. # We pass a structured dict with keys: phenotypes, genes, labs, diseases. if hasattr(model, "encode_patient"): return model.encode_patient(snapshot_dict) return None except Exception as e: logger.error(f"HGT encode failed, falling back to bootstrap: {e}") return None def encode(snapshot_dict: dict, *, dim: int = DEFAULT_DIM): """Top-level encoder. Tries HGT, falls back to bootstrap. snapshot_dict expected keys: phenotypes: [{hpo_id, name, ...}, ...] diseases: [{orpha, ...}, ...] (optional — confirmed/probable dx) genes: [{symbol, ...}, ...] """ v = encode_hgt(snapshot_dict, dim=dim) if v is not None: return v, "hgt" vec, qual = encode_bootstrap( phenotypes=snapshot_dict.get("phenotypes", []), diseases=snapshot_dict.get("diseases", []), genes=snapshot_dict.get("genes", []), dim=dim, ) return vec, f"bootstrap_{qual}" def encode_patient_space(space) -> tuple: """Convenience: encode a `PatientSpace` object directly.""" snap = space.get_current_snapshot() if hasattr(space, "get_current_snapshot") else None if snap is None: # Fallback: no snapshots, build from hypotheses + recent events diseases = [] for hyp in (getattr(space, "_hypotheses", {}) or {}).values(): if hasattr(hyp, "orpha_code") and hyp.orpha_code: diseases.append({"orpha": hyp.orpha_code}) return encode({ "phenotypes": [], "diseases": diseases, "genes": [], }) return encode({ "phenotypes": snap.phenotypes, "diseases": snap.diagnoses, "genes": snap.genes, })