# gemeo-twin-stack / src/gemeo/external_kg.py
# Uploaded by: timmers
# Commit 089d665 (verified): GEMEO world-model — initial release
#   (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
"""External knowledge: pre-computed fused embeddings + PrimeKG hetero graph.
Loads the 27,686 BioLORD+graph-fused embeddings (3072-d) for diseases,
phenotypes, and genes. These were produced by raras-app's GNN+BioLORD
fusion pipeline on PrimeKG and are dropped into Gemeo to enrich:
- cohort matching: nearest-neighbour over fused embeddings instead
of just BioLORD text similarity
- subgraph extraction: anchor on PrimeKG hetero relations instead of
sparse Aura graph
- reverse phenotyping: similar-disease lookup with semantic+graph signal
This is a drop-in upgrade with NO retraining required.
"""
from __future__ import annotations
import json
import logging
import os
from functools import lru_cache
from typing import Optional
import numpy as np
logger = logging.getLogger("gemeo.external_kg")

# Bundled artifacts live next to this module, in ./data.
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

# Resolution order for the graph-ml directory:
#   1. $GEMEO_GRAPH_ML_DIR whenever it is set (even to an empty string);
#   2. the bundled ./data directory, if it exists on disk;
#   3. a hard-coded developer-workstation path holding the full fp64
#      export from raras-app.
# NOTE(review): (3) is machine-specific; on any other host it will not exist
# and load_fused_embeddings() will log a warning and return {}.
if os.path.exists(_REPO_DATA):
    _FALLBACK_GRAPH_ML = _REPO_DATA
else:
    _FALLBACK_GRAPH_ML = "/Users/dimas/raras-app/data/graph-ml"
DEFAULT_GRAPH_ML = os.environ.get("GEMEO_GRAPH_ML_DIR", _FALLBACK_GRAPH_ML)

# Archive name inside DEFAULT_GRAPH_ML: prefer the shipped fp16 variant and
# fall back to the full-precision dev export otherwise.
_FP16_FNAME = "fused_embeddings_fp16.npz"
if os.path.exists(os.path.join(DEFAULT_GRAPH_ML, _FP16_FNAME)):
    _FUSED_FNAME = _FP16_FNAME
else:
    _FUSED_FNAME = "fused_embeddings.npz"
@lru_cache(maxsize=1)
def load_fused_embeddings(graph_ml_dir: Optional[str] = None) -> dict:
    """Load pre-computed fused 3072-d embeddings + node id maps.

    Args:
        graph_ml_dir: Directory holding the fused ``.npz`` archive and
            ``node_ids.json``. A falsy value falls back to
            ``DEFAULT_GRAPH_ML``.

    Returns:
        A dict with, for every node kind (``disease``, ``phenotype``,
        ``gene``) present in BOTH the archive and ``node_ids.json``:

        - ``{kind}_emb``: ``np.ndarray`` with one row per node (expected
          10468 / 11652 / 5566 rows of 3072-d for the shipped data)
        - ``{kind}_idx2id``: ``{row_index: node_id_str}`` (Orphanet code,
          HPO id or HGNC id respectively)
        - ``{kind}_id2idx``: the inverse ``{node_id_str: row_index}`` map

        An empty dict when the directory or either file is missing.

    Notes:
        - float16 arrays (the shipped ``fused_embeddings_fp16.npz``) are
          upcast to float32. Cosine scores computed directly in fp16 keep
          only ~3 significant digits (producing ranking ties), and the
          ``1e-9`` norm epsilon used by the similarity helpers rounds to
          zero in fp16. Other dtypes are returned as stored.
        - The result is memoised with ``lru_cache(maxsize=1)`` and shared by
          every caller, so treat the arrays and maps as read-only. A failed
          (empty) load is cached as well, until
          ``load_fused_embeddings.cache_clear()`` is called.
    """
    d = graph_ml_dir or DEFAULT_GRAPH_ML
    if not os.path.exists(d):
        logger.warning(f"graph-ml dir not found: {d}")
        return {}
    # Resolve the archive name per directory (shipped fp16 first, then the
    # dev export). The module-level _FUSED_FNAME is probed against
    # DEFAULT_GRAPH_ML only, so reusing it for a caller-supplied directory
    # can name a file that does not exist there.
    fused_path = os.path.join(d, "fused_embeddings.npz")
    for fname in ("fused_embeddings_fp16.npz", "fused_embeddings.npz"):
        candidate = os.path.join(d, fname)
        if os.path.exists(candidate):
            fused_path = candidate
            break
    nid_path = os.path.join(d, "node_ids.json")
    if not (os.path.exists(fused_path) and os.path.exists(nid_path)):
        logger.warning(f"missing fused embeddings or node_ids in {d}")
        return {}
    with open(nid_path, encoding="utf-8") as f:
        nids = json.load(f)
    out = {}
    # np.load on an .npz returns an NpzFile that keeps the archive open and
    # reads members lazily; the with-block closes it once every array we
    # need has been materialised.
    with np.load(fused_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            if kind not in fz.files or kind not in nids:
                continue
            emb = fz[kind]
            if emb.dtype == np.float16:
                emb = emb.astype(np.float32)
            idx2id = {int(k): v for k, v in nids[kind].items()}
            id2idx = {v: int(k) for k, v in nids[kind].items()}
            # The similarity helpers index rows by position and map them back
            # through idx2id, so a count mismatch surfaces later as an
            # IndexError/KeyError; flag it at load time instead.
            if len(idx2id) != emb.shape[0]:
                logger.warning(
                    f"{kind}: {len(idx2id)} node ids but {emb.shape[0]} embedding rows"
                )
            out[f"{kind}_emb"] = emb
            out[f"{kind}_idx2id"] = idx2id
            out[f"{kind}_id2idx"] = id2idx
            logger.info(f" loaded {kind}: {emb.shape} embeddings, {len(id2idx)} ids")
    return out
def disease_neighbors(orpha_code: str, k: int = 10,
                      graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the ``k`` diseases nearest to ``orpha_code`` in fused space.

    Similarity is the cosine between rows of the fused disease embedding
    matrix returned by ``load_fused_embeddings``.

    Args:
        orpha_code: Orphanet code of the query disease. It is passed through
            ``str()`` before lookup, so ``100`` and ``"100"`` are equivalent.
        k: Number of neighbours to return; ``k <= 0`` returns ``[]``.
        graph_ml_dir: Forwarded to ``load_fused_embeddings``.

    Returns:
        ``(orpha_code_str, cosine)`` pairs in decreasing similarity order,
        never including the query disease itself. ``[]`` when the embeddings
        are unavailable or the code is unknown.
    """
    if k <= 0:
        # Guard explicitly: a negative k would make the slice below keep
        # almost every row instead of none.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg:
        return []
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    emb = kg["disease_emb"]
    qv = emb[qi]
    # Cosine similarity of the query row against every disease row.
    norms = np.linalg.norm(emb, axis=1) + 1e-9
    qn = np.linalg.norm(qv) + 1e-9
    sims = (emb @ qv) / (norms * qn)
    # Exclude the query by index, not by rank: a disease with an identical
    # (or rounding-equal) embedding can sort ahead of the query, in which
    # case skipping sorted position 0 would drop a real neighbour and keep
    # the query itself.
    order = np.argsort(-sims)
    top = order[order != qi][:k]
    idx2id = kg["disease_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def phenotype_for_disease(orpha_code: str, k: int = 20,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return k phenotypes most similar in fused space to the given disease.

    NOTE: This uses cross-modal cosine similarity in fused space; it is a
    proxy for true disease→phenotype edges from PrimeKG. For ground-truth
    disease→HPO links use Orphanet/HPO directly.

    Args:
        orpha_code: Orphanet code of the query disease (coerced with
            ``str()`` before lookup).
        k: Number of phenotypes to return; ``k <= 0`` returns ``[]``.
        graph_ml_dir: Forwarded to ``load_fused_embeddings``.

    Returns:
        ``(hpo_id_str, cosine)`` pairs in decreasing similarity order, or
        ``[]`` when the embeddings are unavailable or the code is unknown.
    """
    if k <= 0:
        # A negative k would turn ``[:k]`` into "all but the last |k|" rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg or "phenotype_emb" not in kg:
        return []
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    qv = kg["disease_emb"][qi]
    pe = kg["phenotype_emb"]
    # Cosine similarity of the disease vector against every phenotype row.
    norms = np.linalg.norm(pe, axis=1) + 1e-9
    qn = np.linalg.norm(qv) + 1e-9
    sims = (pe @ qv) / (norms * qn)
    top = np.argsort(-sims)[:k]
    idx2id = kg["phenotype_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def gene_for_disease(orpha_code: str, k: int = 10,
                     graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return k genes most semantically related to the given disease.

    Like ``phenotype_for_disease``, this is cross-modal cosine similarity
    in fused space, not a lookup of curated disease→gene associations.

    Args:
        orpha_code: Orphanet code of the query disease (coerced with
            ``str()`` before lookup).
        k: Number of genes to return; ``k <= 0`` returns ``[]``.
        graph_ml_dir: Forwarded to ``load_fused_embeddings``.

    Returns:
        ``(hgnc_id_str, cosine)`` pairs in decreasing similarity order, or
        ``[]`` when the embeddings are unavailable or the code is unknown.
    """
    if k <= 0:
        # A negative k would turn ``[:k]`` into "all but the last |k|" rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg or "gene_emb" not in kg:
        return []
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    qv = kg["disease_emb"][qi]
    ge = kg["gene_emb"]
    # Cosine similarity of the disease vector against every gene row.
    norms = np.linalg.norm(ge, axis=1) + 1e-9
    qn = np.linalg.norm(qv) + 1e-9
    sims = (ge @ qv) / (norms * qn)
    top = np.argsort(-sims)[:k]
    idx2id = kg["gene_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def patient_disease_match(patient_emb: np.ndarray, k: int = 10,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Given a 3072-d patient embedding, return the k closest diseases.

    Useful when build_gemeo wants to re-rank diagnoses against the
    full PrimeKG-fused space, not just the local Aura graph.

    Args:
        patient_emb: A single patient vector whose width matches the disease
            embeddings. Any array-like is accepted; a one-row ``(1, D)``
            array is treated as ``(D,)``.
        k: Number of diseases to return; ``k <= 0`` returns ``[]``.
        graph_ml_dir: Forwarded to ``load_fused_embeddings``.

    Returns:
        ``(orpha_code_str, cosine)`` pairs in decreasing similarity order.
        ``[]`` when the embeddings are unavailable, or (with a warning) when
        ``patient_emb`` is not a single vector of the disease width.
    """
    if k <= 0:
        # A negative k would turn ``[:k]`` into "all but the last |k|" rows.
        return []
    kg = load_fused_embeddings(graph_ml_dir)
    if not kg or "disease_emb" not in kg:
        return []
    de = kg["disease_emb"]
    # np.asarray lets lists/tuples through (``.shape`` on a list raised
    # AttributeError); a single-row matrix is squeezed to a vector.
    query = np.asarray(patient_emb)
    if query.ndim == 2 and query.shape[0] == 1:
        query = query[0]
    # Only one vector can be scored: checking just the last axis let a
    # (B, D) batch through, which then broke (or silently misbehaved) in
    # the matmul below.
    if query.ndim != 1 or query.shape[0] != de.shape[1]:
        logger.warning(f"dim mismatch: patient {np.shape(patient_emb)}, disease {de.shape}")
        return []
    norms = np.linalg.norm(de, axis=1) + 1e-9
    pn = np.linalg.norm(query) + 1e-9
    sims = (de @ query) / (norms * pn)
    top = np.argsort(-sims)[:k]
    idx2id = kg["disease_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
    print("=== sanity check: external_kg ===\n")

    # 1) Report which embedding matrices were found, and their shapes.
    loaded = load_fused_embeddings()
    for key in ("disease_emb", "phenotype_emb", "gene_emb"):
        status = loaded[key].shape if key in loaded else "MISSING"
        print(f"{key}: {status}")

    # 2) Probe every lookup with Orphanet code "100" (labelled ATM in the
    #    headers) and print one formatted row per hit. Each probe is
    #    (header, query thunk, row formatter).
    probes = (
        (
            "ATM (ORPHA:100) nearest diseases:",
            lambda: disease_neighbors("100", k=8),
            lambda o, s: f" ORPHA:{o:>6} sim={s:.3f}",
        ),
        (
            "ATM (ORPHA:100) similar phenotypes (top 10):",
            lambda: phenotype_for_disease("100", k=10),
            lambda h, s: f" HP:{h:>10} sim={s:.3f}",
        ),
        (
            "ATM (ORPHA:100) similar genes (top 5):",
            lambda: gene_for_disease("100", k=5),
            lambda g, s: f" HGNC:{g:>10} sim={s:.3f}",
        ),
    )
    for header, run_query, format_row in probes:
        print()
        print(header)
        for ident, score in run_query():
            print(format_row(ident, score))