"""External knowledge: pre-computed fused embeddings + PrimeKG hetero graph. Loads the 27,686 BioLORD+graph-fused embeddings (3072-d) for diseases, phenotypes, and genes. These were produced by raras-app's GNN+BioLORD fusion pipeline on PrimeKG and are dropped into Gemeo to enrich: - cohort matching: nearest-neighbour over fused embeddings instead of just BioLORD text similarity - subgraph extraction: anchor on PrimeKG hetero relations instead of sparse Aura graph - reverse phenotyping: similar-disease lookup with semantic+graph signal This is a drop-in upgrade with NO retraining required. """ from __future__ import annotations import json import logging import os from functools import lru_cache from typing import Optional import numpy as np logger = logging.getLogger("gemeo.external_kg") # Defaults to the bundled fp16 artifacts in this repo. Override with # GEMEO_GRAPH_ML_DIR to use a custom location (e.g. dev workstation # with the full fp64 export from raras-app). _REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data") DEFAULT_GRAPH_ML = os.environ.get( "GEMEO_GRAPH_ML_DIR", _REPO_DATA if os.path.exists(_REPO_DATA) else "/Users/dimas/raras-app/data/graph-ml", ) # Choose fp16 fname if available (shipped variant), else fp64 (dev). _FUSED_FNAME = ( "fused_embeddings_fp16.npz" if os.path.exists(os.path.join(DEFAULT_GRAPH_ML, "fused_embeddings_fp16.npz")) else "fused_embeddings.npz" ) @lru_cache(maxsize=1) def load_fused_embeddings(graph_ml_dir: str = None) -> dict: """Load pre-computed fused 3072-d embeddings + node id maps. Returns dict with keys: - disease_emb: np.ndarray (10468, 3072) - phenotype_emb: np.ndarray (11652, 3072) - gene_emb: np.ndarray (5566, 3072) - disease_idx2id: {pos: orpha_code_str} - phenotype_idx2id: {pos: hpo_id_str} - gene_idx2id: {pos: hgnc_id_str} - disease_id2idx, phenotype_id2idx, gene_id2idx: inverse maps """ d = graph_ml_dir or DEFAULT_GRAPH_ML if not os.path.exists(d): logger.warning(f"graph-ml dir not found: {d}") return {} fused_path = os.path.join(d, _FUSED_FNAME) nid_path = os.path.join(d, "node_ids.json") if not (os.path.exists(fused_path) and os.path.exists(nid_path)): logger.warning(f"missing fused embeddings or node_ids in {d}") return {} fz = np.load(fused_path) with open(nid_path) as f: nids = json.load(f) out = {} for kind in ("disease", "phenotype", "gene"): if kind in fz.files and kind in nids: emb = fz[kind] idx2id = {int(k): v for k, v in nids[kind].items()} id2idx = {v: int(k) for k, v in nids[kind].items()} out[f"{kind}_emb"] = emb out[f"{kind}_idx2id"] = idx2id out[f"{kind}_id2idx"] = id2idx logger.info(f" loaded {kind}: {emb.shape} embeddings, {len(id2idx)} ids") return out def disease_neighbors(orpha_code: str, k: int = 10, graph_ml_dir: str = None) -> list[tuple[str, float]]: """Return k nearest diseases (by orpha) to the given orpha by fused embedding.""" kg = load_fused_embeddings(graph_ml_dir) if not kg or "disease_emb" not in kg: return [] emb = kg["disease_emb"] id2idx = kg["disease_id2idx"] idx2id = kg["disease_idx2id"] if str(orpha_code) not in id2idx: return [] qi = id2idx[str(orpha_code)] qv = emb[qi] # Cosine similarity norms = np.linalg.norm(emb, axis=1) + 1e-9 qn = np.linalg.norm(qv) + 1e-9 sims = (emb @ qv) / (norms * qn) top = np.argsort(-sims)[1:k + 1] # skip self return [(idx2id[int(i)], float(sims[i])) for i in top] def phenotype_for_disease(orpha_code: str, k: int = 20, graph_ml_dir: str = None) -> list[tuple[str, float]]: """Return k phenotypes most similar in fused space to the given disease. NOTE: This uses cross-modal cosine similarity in fused space; it is a proxy for true disease→phenotype edges from PrimeKG. For ground-truth disease→HPO links use Orphanet/HPO directly. """ kg = load_fused_embeddings(graph_ml_dir) if not kg or "disease_emb" not in kg or "phenotype_emb" not in kg: return [] if str(orpha_code) not in kg["disease_id2idx"]: return [] qi = kg["disease_id2idx"][str(orpha_code)] qv = kg["disease_emb"][qi] pe = kg["phenotype_emb"] norms = np.linalg.norm(pe, axis=1) + 1e-9 qn = np.linalg.norm(qv) + 1e-9 sims = (pe @ qv) / (norms * qn) top = np.argsort(-sims)[:k] return [(kg["phenotype_idx2id"][int(i)], float(sims[i])) for i in top] def gene_for_disease(orpha_code: str, k: int = 10, graph_ml_dir: str = None) -> list[tuple[str, float]]: """Return k genes most semantically related to the given disease.""" kg = load_fused_embeddings(graph_ml_dir) if not kg or "disease_emb" not in kg or "gene_emb" not in kg: return [] if str(orpha_code) not in kg["disease_id2idx"]: return [] qi = kg["disease_id2idx"][str(orpha_code)] qv = kg["disease_emb"][qi] ge = kg["gene_emb"] norms = np.linalg.norm(ge, axis=1) + 1e-9 qn = np.linalg.norm(qv) + 1e-9 sims = (ge @ qv) / (norms * qn) top = np.argsort(-sims)[:k] return [(kg["gene_idx2id"][int(i)], float(sims[i])) for i in top] def patient_disease_match(patient_emb: np.ndarray, k: int = 10, graph_ml_dir: str = None) -> list[tuple[str, float]]: """Given a 3072-d patient embedding, return k closest diseases. Useful when build_gemeo wants to re-rank diagnoses against the full PrimeKG-fused space, not just the local Aura graph. """ kg = load_fused_embeddings(graph_ml_dir) if not kg or "disease_emb" not in kg: return [] de = kg["disease_emb"] if patient_emb.shape[-1] != de.shape[1]: logger.warning(f"dim mismatch: patient {patient_emb.shape}, disease {de.shape}") return [] norms = np.linalg.norm(de, axis=1) + 1e-9 pn = np.linalg.norm(patient_emb) + 1e-9 sims = (de @ patient_emb) / (norms * pn) top = np.argsort(-sims)[:k] return [(kg["disease_idx2id"][int(i)], float(sims[i])) for i in top] if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") print("=== sanity check: external_kg ===\n") kg = load_fused_embeddings() for k in ("disease_emb", "phenotype_emb", "gene_emb"): if k in kg: print(f"{k}: {kg[k].shape}") else: print(f"{k}: MISSING") print() # ATM neighbors (ORPHA:100) print("ATM (ORPHA:100) nearest diseases:") for o, s in disease_neighbors("100", k=8): print(f" ORPHA:{o:>6} sim={s:.3f}") print() print("ATM (ORPHA:100) similar phenotypes (top 10):") for h, s in phenotype_for_disease("100", k=10): print(f" HP:{h:>10} sim={s:.3f}") print() print("ATM (ORPHA:100) similar genes (top 5):") for g, s in gene_for_disease("100", k=5): print(f" HGNC:{g:>10} sim={s:.3f}")