| """External knowledge: pre-computed fused embeddings + PrimeKG hetero graph. |
| |
| Loads the 27,686 BioLORD+graph-fused embeddings (3072-d) for diseases, |
| phenotypes, and genes. These were produced by raras-app's GNN+BioLORD |
| fusion pipeline on PrimeKG and are dropped into Gemeo to enrich: |
| |
| - cohort matching: nearest-neighbour over fused embeddings instead |
| of just BioLORD text similarity |
| - subgraph extraction: anchor on PrimeKG hetero relations instead of |
| sparse Aura graph |
| - reverse phenotyping: similar-disease lookup with semantic+graph signal |
| |
| This is a drop-in upgrade with NO retraining required. |
| """ |
| from __future__ import annotations |
| import json |
| import logging |
| import os |
| from functools import lru_cache |
| from typing import Optional |
|
|
| import numpy as np |
|
|
| logger = logging.getLogger("gemeo.external_kg") |
|
|
| |
| |
| |
# Directory named "data" that sits next to this module.
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

# Where the graph-ML artefacts are looked up, in priority order:
#   1. the GEMEO_GRAPH_ML_DIR environment variable, when set;
#   2. the bundled data directory, when it exists on disk;
#   3. a fixed absolute fallback path.
if "GEMEO_GRAPH_ML_DIR" in os.environ:
    DEFAULT_GRAPH_ML = os.environ["GEMEO_GRAPH_ML_DIR"]
elif os.path.exists(_REPO_DATA):
    DEFAULT_GRAPH_ML = _REPO_DATA
else:
    DEFAULT_GRAPH_ML = "/Users/dimas/raras-app/data/graph-ml"

# File name of the fused-embedding archive: the "_fp16" variant is used when it
# is present in DEFAULT_GRAPH_ML at import time, otherwise the plain archive.
_FUSED_FNAME = "fused_embeddings.npz"
if os.path.exists(os.path.join(DEFAULT_GRAPH_ML, "fused_embeddings_fp16.npz")):
    _FUSED_FNAME = "fused_embeddings_fp16.npz"
|
|
|
|
@lru_cache(maxsize=1)
def _load_fused_embeddings_cached(d: str) -> dict:
    """Read the fused embeddings and node-id maps stored in directory ``d``.

    Cached on the already-resolved directory so that every way of asking for
    the default location shares a single cache entry.
    """
    if not os.path.exists(d):
        logger.warning("graph-ml dir not found: %s", d)
        return {}

    # Resolve the archive inside *this* directory (fp16 variant preferred)
    # rather than relying on a file name chosen for the default directory.
    candidates = (
        os.path.join(d, "fused_embeddings_fp16.npz"),
        os.path.join(d, "fused_embeddings.npz"),
    )
    fused_path = next((p for p in candidates if os.path.exists(p)), None)
    nid_path = os.path.join(d, "node_ids.json")

    if fused_path is None or not os.path.exists(nid_path):
        logger.warning("missing fused embeddings or node_ids in %s", d)
        return {}

    with open(nid_path, encoding="utf-8") as f:
        nids = json.load(f)

    out = {}
    # NpzFile keeps the archive open until closed; arrays are materialised on
    # item access, so everything needed is read inside the ``with`` block.
    with np.load(fused_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            if kind not in fz.files or kind not in nids:
                continue
            emb = fz[kind]
            # node_ids.json maps stringified row positions to node ids.
            idx2id = {int(pos): node_id for pos, node_id in nids[kind].items()}
            id2idx = {node_id: pos for pos, node_id in idx2id.items()}
            out[f"{kind}_emb"] = emb
            out[f"{kind}_idx2id"] = idx2id
            out[f"{kind}_id2idx"] = id2idx
            logger.info(" loaded %s: %s embeddings, %d ids",
                        kind, emb.shape, len(id2idx))
            if emb.shape[0] != len(idx2id):
                logger.warning("%s: %d embedding rows but %d node ids",
                               kind, emb.shape[0], len(idx2id))

    return out


def load_fused_embeddings(graph_ml_dir: Optional[str] = None) -> dict:
    """Load pre-computed fused embeddings + node id maps.

    Args:
        graph_ml_dir: directory holding ``fused_embeddings_fp16.npz`` (or
            ``fused_embeddings.npz``) and ``node_ids.json``. Falls back to
            ``DEFAULT_GRAPH_ML`` when omitted, ``None`` or empty.

    Returns:
        A dict that, for each kind in ("disease", "phenotype", "gene") present
        in both files, holds:
          - ``{kind}_emb``: np.ndarray with one row per node
          - ``{kind}_idx2id``: {row position: node id}
          - ``{kind}_id2idx``: the inverse map
        An empty dict is returned (after a warning) when the directory or
        either file is missing.

    The most recently requested directory is memoised, including an empty
    result, and the same dict object is handed to every caller, so treat it
    as read-only.
    """
    # Normalise *before* the cache lookup: ``f()`` and ``f(None)`` would
    # otherwise be distinct lru_cache keys and evict each other (maxsize=1).
    return _load_fused_embeddings_cached(graph_ml_dir or DEFAULT_GRAPH_ML)


# Keep the lru_cache helpers reachable under the public name for callers that
# used them when the decorator sat directly on ``load_fused_embeddings``.
load_fused_embeddings.cache_clear = _load_fused_embeddings_cached.cache_clear
load_fused_embeddings.cache_info = _load_fused_embeddings_cached.cache_info
|
|
|
|
def disease_neighbors(orpha_code: str, k: int = 10,
                      graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the ``k`` diseases nearest to ``orpha_code`` by fused embedding.

    Args:
        orpha_code: id of the query disease; compared as ``str(orpha_code)``.
        k: maximum number of neighbours to return; ``k <= 0`` yields ``[]``.
        graph_ml_dir: optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(disease_id, cosine_similarity)`` pairs in descending similarity,
        never containing the query disease itself. Empty when the embeddings
        are unavailable or the code is unknown.
    """
    if k <= 0:
        # A negative k would otherwise slice from the end of the ranking.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg:
        return []
    idx2id = kg["disease_idx2id"]
    query_idx = kg["disease_id2idx"].get(str(orpha_code))
    if query_idx is None:
        return []

    emb = kg["disease_emb"]
    # Work in at least float32: in float16 the 1e-9 guard rounds to zero and
    # the similarities carry only ~3 significant digits. No copy is made when
    # the stored dtype is already float32/float64.
    emb = np.asarray(emb, dtype=np.result_type(emb.dtype, np.float32))
    query = emb[query_idx]

    norms = np.linalg.norm(emb, axis=1) + 1e-9
    query_norm = np.linalg.norm(query) + 1e-9
    sims = (emb @ query) / (norms * query_norm)

    # Drop the query row explicitly instead of assuming it ranks first; ties
    # or duplicate embeddings can place another row ahead of it.
    order = np.argsort(-sims)
    order = order[order != query_idx][:k]
    return [(idx2id[int(i)], float(sims[i])) for i in order]
|
|
|
|
def phenotype_for_disease(orpha_code: str, k: int = 20,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the ``k`` phenotypes most similar in fused space to a disease.

    NOTE: This uses cross-modal cosine similarity in fused space; it is a
    proxy for true disease→phenotype edges from PrimeKG. For ground-truth
    disease→HPO links use Orphanet/HPO directly.

    Args:
        orpha_code: id of the query disease; compared as ``str(orpha_code)``.
        k: maximum number of phenotypes to return; ``k <= 0`` yields ``[]``.
        graph_ml_dir: optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(phenotype_id, cosine_similarity)`` pairs in descending similarity.
        Empty when the embeddings are unavailable or the code is unknown.
    """
    if k <= 0:
        # A negative k would otherwise return all but the last |k| rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg or "phenotype_emb" not in kg:
        return []
    query_idx = kg["disease_id2idx"].get(str(orpha_code))
    if query_idx is None:
        return []

    query = kg["disease_emb"][query_idx]
    pheno = kg["phenotype_emb"]
    # Work in at least float32: in float16 the 1e-9 guard rounds to zero and
    # the similarities carry only ~3 significant digits.
    work_dtype = np.result_type(pheno.dtype, query.dtype, np.float32)
    pheno = np.asarray(pheno, dtype=work_dtype)
    query = np.asarray(query, dtype=work_dtype)

    norms = np.linalg.norm(pheno, axis=1) + 1e-9
    query_norm = np.linalg.norm(query) + 1e-9
    sims = (pheno @ query) / (norms * query_norm)

    top = np.argsort(-sims)[:k]
    idx2id = kg["phenotype_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
|
|
|
|
def gene_for_disease(orpha_code: str, k: int = 10,
                     graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the ``k`` genes most similar in fused space to a disease.

    Similarity is the cosine between the disease's fused embedding and every
    gene embedding; it is a ranking signal, not a curated association.

    Args:
        orpha_code: id of the query disease; compared as ``str(orpha_code)``.
        k: maximum number of genes to return; ``k <= 0`` yields ``[]``.
        graph_ml_dir: optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(gene_id, cosine_similarity)`` pairs in descending similarity.
        Empty when the embeddings are unavailable or the code is unknown.
    """
    if k <= 0:
        # A negative k would otherwise return all but the last |k| rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg or "gene_emb" not in kg:
        return []
    query_idx = kg["disease_id2idx"].get(str(orpha_code))
    if query_idx is None:
        return []

    query = kg["disease_emb"][query_idx]
    genes = kg["gene_emb"]
    # Work in at least float32: in float16 the 1e-9 guard rounds to zero and
    # the similarities carry only ~3 significant digits.
    work_dtype = np.result_type(genes.dtype, query.dtype, np.float32)
    genes = np.asarray(genes, dtype=work_dtype)
    query = np.asarray(query, dtype=work_dtype)

    norms = np.linalg.norm(genes, axis=1) + 1e-9
    query_norm = np.linalg.norm(query) + 1e-9
    sims = (genes @ query) / (norms * query_norm)

    top = np.argsort(-sims)[:k]
    idx2id = kg["gene_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
|
|
|
|
def patient_disease_match(patient_emb: np.ndarray, k: int = 10,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Given a patient embedding, return the ``k`` closest diseases.

    Useful when build_gemeo wants to re-rank diagnoses against the
    full PrimeKG-fused space, not just the local Aura graph.

    Args:
        patient_emb: a single embedding with as many values as the disease
            embeddings have columns. Any array-like is accepted and is
            flattened, so shapes such as ``(D,)``, ``(1, D)`` or ``(D, 1)``
            all work.
        k: maximum number of diseases to return; ``k <= 0`` yields ``[]``.
        graph_ml_dir: optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(disease_id, cosine_similarity)`` pairs in descending similarity.
        Empty when the embeddings are unavailable or the dimensionality does
        not match (a warning is logged in the latter case).
    """
    if k <= 0:
        # A negative k would otherwise return all but the last |k| rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg:
        return []
    de = kg["disease_emb"]

    patient = np.asarray(patient_emb)
    # Compare the total element count, not just the last axis: a batched
    # (N, D) input cannot be scored against the matrix as a single query.
    if patient.size != de.shape[1]:
        logger.warning("dim mismatch: patient %s, disease %s",
                       patient.shape, de.shape)
        return []
    query = patient.reshape(-1)

    # Work in at least float32: in float16 the 1e-9 guard rounds to zero and
    # the similarities carry only ~3 significant digits.
    work_dtype = np.result_type(de.dtype, query.dtype, np.float32)
    de = np.asarray(de, dtype=work_dtype)
    query = np.asarray(query, dtype=work_dtype)

    norms = np.linalg.norm(de, axis=1) + 1e-9
    query_norm = np.linalg.norm(query) + 1e-9
    sims = (de @ query) / (norms * query_norm)

    top = np.argsort(-sims)[:k]
    idx2id = kg["disease_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: report which embedding matrices loaded, then print a
    # few lookups for disease code "100".
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
    print("=== sanity check: external_kg ===\n")

    loaded = load_fused_embeddings()
    for key in ("disease_emb", "phenotype_emb", "gene_emb"):
        status = loaded[key].shape if key in loaded else "MISSING"
        print(f"{key}: {status}")

    # (heading, lookup function, k, row label, id column width)
    demos = (
        ("ATM (ORPHA:100) nearest diseases:", disease_neighbors, 8, " ORPHA:", 6),
        ("ATM (ORPHA:100) similar phenotypes (top 10):", phenotype_for_disease, 10, " HP:", 10),
        ("ATM (ORPHA:100) similar genes (top 5):", gene_for_disease, 5, " HGNC:", 10),
    )
    for heading, lookup, top_k, label, width in demos:
        print()  # blank separator before every section
        print(heading)
        for ident, score in lookup("100", k=top_k):
            print(f"{label}{ident:>{width}} sim={score:.3f}")
|
|