"""External knowledge: pre-computed fused embeddings + PrimeKG hetero graph.

Loads the 27,686 BioLORD+graph-fused embeddings (3072-d) for diseases,
phenotypes, and genes. These were produced by raras-app's GNN+BioLORD
fusion pipeline on PrimeKG and are dropped into Gemeo to enrich:

- cohort matching: nearest-neighbour search over fused embeddings instead
  of just BioLORD text similarity
- subgraph extraction: anchor on PrimeKG hetero relations instead of
  the sparse Aura graph
- reverse phenotyping: similar-disease lookup with semantic+graph signal

This is a drop-in upgrade with NO retraining required.
"""
from __future__ import annotations
import json
import logging
import os
from functools import lru_cache
from typing import Optional
import numpy as np
logger = logging.getLogger("gemeo.external_kg")

# Defaults to the bundled fp16 artifacts in this repo. Override with
# GEMEO_GRAPH_ML_DIR to use a custom location (e.g. a dev workstation
# with the full fp64 export from raras-app).
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
DEFAULT_GRAPH_ML = os.environ.get(
    "GEMEO_GRAPH_ML_DIR",
    _REPO_DATA if os.path.exists(_REPO_DATA) else "/Users/dimas/raras-app/data/graph-ml",
)

# Archive names this module knows about, preferred first: the fp16 file is
# the shipped variant, the un-suffixed file is the fp64 dev export.
_FUSED_FNAMES = ("fused_embeddings_fp16.npz", "fused_embeddings.npz")

# Name detected in DEFAULT_GRAPH_ML at import time. It is still tried first
# in every directory, so existing setups keep resolving to the same file.
_FUSED_FNAME = (
    _FUSED_FNAMES[0]
    if os.path.exists(os.path.join(DEFAULT_GRAPH_ML, _FUSED_FNAMES[0]))
    else _FUSED_FNAMES[1]
)


def _find_fused_archive(directory: str) -> Optional[str]:
    """Return the path of the first known embedding archive in *directory*, or None."""
    # dict.fromkeys de-duplicates while keeping _FUSED_FNAME in first position.
    for fname in dict.fromkeys((_FUSED_FNAME, *_FUSED_FNAMES)):
        candidate = os.path.join(directory, fname)
        if os.path.exists(candidate):
            return candidate
    return None


@lru_cache(maxsize=1)
def _load_fused_embeddings(directory: str) -> dict:
    """Read the embedding archive and id maps from *directory* (cached per directory)."""
    if not os.path.exists(directory):
        logger.warning(f"graph-ml dir not found: {directory}")
        return {}
    fused_path = _find_fused_archive(directory)
    nid_path = os.path.join(directory, "node_ids.json")
    if fused_path is None or not os.path.exists(nid_path):
        logger.warning(f"missing fused embeddings or node_ids in {directory}")
        return {}

    with open(nid_path, encoding="utf-8") as f:
        nids = json.load(f)

    out = {}
    # NpzFile holds the underlying zip open until it is closed. Indexing it
    # reads the array into memory, so nothing needs the file after this block.
    with np.load(fused_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            # A node type is only exposed when both its matrix and id map exist.
            if kind not in fz.files or kind not in nids:
                continue
            emb = fz[kind]
            # JSON object keys are strings; positions are stored as ints.
            idx2id = {int(pos): node_id for pos, node_id in nids[kind].items()}
            id2idx = {node_id: int(pos) for pos, node_id in nids[kind].items()}
            out[f"{kind}_emb"] = emb
            out[f"{kind}_idx2id"] = idx2id
            out[f"{kind}_id2idx"] = id2idx
            logger.info(f" loaded {kind}: {emb.shape} embeddings, {len(id2idx)} ids")
    return out


def load_fused_embeddings(graph_ml_dir: Optional[str] = None) -> dict:
    """Load pre-computed fused 3072-d embeddings + node id maps.

    Args:
        graph_ml_dir: Directory holding the embedding archive and
            ``node_ids.json``. Falls back to ``DEFAULT_GRAPH_ML`` when
            omitted or empty.

    Returns:
        An empty dict (after logging a warning) when the directory, the
        archive or ``node_ids.json`` is missing. Otherwise a dict with keys,
        for each node type present in both files:
        - disease_emb: np.ndarray (10468, 3072)
        - phenotype_emb: np.ndarray (11652, 3072)
        - gene_emb: np.ndarray (5566, 3072)
        - disease_idx2id: {pos: orpha_code_str}
        - phenotype_idx2id: {pos: hpo_id_str}
        - gene_idx2id: {pos: hgnc_id_str}
        - disease_id2idx, phenotype_id2idx, gene_id2idx: inverse maps

    The result for the most recently used directory is cached and the same
    dict object is handed to every caller, so treat it as read-only.
    """
    # Resolve the directory before hitting the cache: lru_cache keys on the
    # exact call signature, so f(), f(None) and f(graph_ml_dir=...) would
    # otherwise evict one another from the single cache slot.
    return _load_fused_embeddings(graph_ml_dir or DEFAULT_GRAPH_ML)


# Keep the cache-management hooks that lru_cache exposes on a decorated function.
load_fused_embeddings.cache_clear = _load_fused_embeddings.cache_clear
load_fused_embeddings.cache_info = _load_fused_embeddings.cache_info
def disease_neighbors(orpha_code: str, k: int = 10,
                      graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the k diseases nearest to ``orpha_code`` by fused embedding.

    Args:
        orpha_code: Disease id to query; converted with ``str()`` before lookup.
        k: Maximum number of neighbours to return. Values <= 0 yield ``[]``.
        graph_ml_dir: Optional artifact directory forwarded to
            ``load_fused_embeddings``.

    Returns:
        ``(disease_id, cosine_similarity)`` pairs, most similar first, never
        including the query disease itself. Empty when the embeddings are
        unavailable or the id is unknown.
    """
    kg = load_fused_embeddings(graph_ml_dir)
    if k <= 0 or not kg or "disease_emb" not in kg:
        return []
    emb = kg["disease_emb"]
    idx2id = kg["disease_idx2id"]
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    qv = emb[qi]

    # Cosine similarity. The division runs in float64 because adding 1e-9 to
    # a float16 array rounds to adding 0, which would leave all-zero rows
    # dividing by zero on fp16 artifacts.
    norms = np.linalg.norm(emb, axis=1).astype(np.float64) + 1e-9
    qn = float(np.linalg.norm(qv)) + 1e-9
    sims = (emb @ qv).astype(np.float64) / (norms * qn)

    order = np.argsort(-sims)
    # Remove the query by index, not by rank: when another disease has an
    # equal similarity score the query is not guaranteed to sort first.
    top = order[order != qi][:k]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def phenotype_for_disease(orpha_code: str, k: int = 20,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the k phenotypes most similar in fused space to the given disease.

    NOTE: This uses cross-modal cosine similarity in fused space; it is a
    proxy for true disease→phenotype edges from PrimeKG. For ground-truth
    disease→HPO links use Orphanet/HPO directly.

    Args:
        orpha_code: Disease id to query; converted with ``str()`` before lookup.
        k: Maximum number of phenotypes to return. Values <= 0 yield ``[]``.
        graph_ml_dir: Optional artifact directory forwarded to
            ``load_fused_embeddings``.

    Returns:
        ``(phenotype_id, cosine_similarity)`` pairs, most similar first.
        Empty when either embedding table is unavailable or the disease id
        is unknown.
    """
    kg = load_fused_embeddings(graph_ml_dir)
    # k <= 0 is rejected explicitly: a negative k in the final slice would
    # return all but the last |k| rows instead of nothing.
    if k <= 0 or not kg or "disease_emb" not in kg or "phenotype_emb" not in kg:
        return []
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    qv = kg["disease_emb"][qi]
    pe = kg["phenotype_emb"]

    # Divide in float64: adding 1e-9 to a float16 array rounds to adding 0,
    # which would leave all-zero rows dividing by zero on fp16 artifacts.
    norms = np.linalg.norm(pe, axis=1).astype(np.float64) + 1e-9
    qn = float(np.linalg.norm(qv)) + 1e-9
    sims = (pe @ qv).astype(np.float64) / (norms * qn)

    top = np.argsort(-sims)[:k]
    idx2id = kg["phenotype_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def gene_for_disease(orpha_code: str, k: int = 10,
                     graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the k genes most similar in fused space to the given disease.

    Args:
        orpha_code: Disease id to query; converted with ``str()`` before lookup.
        k: Maximum number of genes to return. Values <= 0 yield ``[]``.
        graph_ml_dir: Optional artifact directory forwarded to
            ``load_fused_embeddings``.

    Returns:
        ``(gene_id, cosine_similarity)`` pairs, most similar first. Empty
        when either embedding table is unavailable or the disease id is
        unknown.
    """
    kg = load_fused_embeddings(graph_ml_dir)
    # k <= 0 is rejected explicitly: a negative k in the final slice would
    # return all but the last |k| rows instead of nothing.
    if k <= 0 or not kg or "disease_emb" not in kg or "gene_emb" not in kg:
        return []
    qi = kg["disease_id2idx"].get(str(orpha_code))
    if qi is None:
        return []
    qv = kg["disease_emb"][qi]
    ge = kg["gene_emb"]

    # Divide in float64: adding 1e-9 to a float16 array rounds to adding 0,
    # which would leave all-zero rows dividing by zero on fp16 artifacts.
    norms = np.linalg.norm(ge, axis=1).astype(np.float64) + 1e-9
    qn = float(np.linalg.norm(qv)) + 1e-9
    sims = (ge @ qv).astype(np.float64) / (norms * qn)

    top = np.argsort(-sims)[:k]
    idx2id = kg["gene_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
def patient_disease_match(patient_emb: np.ndarray, k: int = 10,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Given a 3072-d patient embedding, return the k closest diseases.

    Useful when build_gemeo wants to re-rank diagnoses against the
    full PrimeKG-fused space, not just the local Aura graph.

    Args:
        patient_emb: A single embedding vector. Any array-like holding
            exactly one vector is accepted (e.g. shape ``(D,)`` or
            ``(1, D)``), where ``D`` is the disease embedding width.
        k: Maximum number of diseases to return. Values <= 0 yield ``[]``.
        graph_ml_dir: Optional artifact directory forwarded to
            ``load_fused_embeddings``.

    Returns:
        ``(disease_id, cosine_similarity)`` pairs, most similar first.
        Empty when the embeddings are unavailable or when the input does
        not hold exactly ``D`` values (a warning is logged in that case).
    """
    kg = load_fused_embeddings(graph_ml_dir)
    if k <= 0 or not kg or "disease_emb" not in kg:
        return []
    de = kg["disease_emb"]

    vec = np.asarray(patient_emb)
    # Compare total size rather than only the last axis: a (1, D) row vector
    # is a valid single embedding, while an (n, D) batch is not, and either
    # would make the matrix-vector product below raise if passed through.
    if vec.size != de.shape[1]:
        logger.warning(f"dim mismatch: patient {vec.shape}, disease {de.shape}")
        return []
    vec = vec.reshape(-1)

    # Divide in float64: adding 1e-9 to a float16 array rounds to adding 0,
    # which would leave all-zero rows dividing by zero on fp16 artifacts.
    norms = np.linalg.norm(de, axis=1).astype(np.float64) + 1e-9
    pn = float(np.linalg.norm(vec)) + 1e-9
    sims = (de @ vec).astype(np.float64) / (norms * pn)

    top = np.argsort(-sims)[:k]
    idx2id = kg["disease_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]
if __name__ == "__main__":
    # Manual smoke test: load the artifacts, report what was found, then
    # print a few lookups for ORPHA:100 (labelled "ATM" in the headings).
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
    print("=== sanity check: external_kg ===\n")

    loaded = load_fused_embeddings()
    for table in ("disease_emb", "phenotype_emb", "gene_emb"):
        status = loaded[table].shape if table in loaded else "MISSING"
        print(f"{table}: {status}")

    # One row per report: heading, lookup function, result count,
    # id prefix and right-aligned id column width.
    reports = (
        ("ATM (ORPHA:100) nearest diseases:", disease_neighbors, 8, "ORPHA", 6),
        ("ATM (ORPHA:100) similar phenotypes (top 10):", phenotype_for_disease, 10, "HP", 10),
        ("ATM (ORPHA:100) similar genes (top 5):", gene_for_disease, 5, "HGNC", 10),
    )
    for heading, lookup, count, prefix, width in reports:
        print()  # blank line separating this report from the previous output
        print(heading)
        for ident, score in lookup("100", k=count):
            print(f" {prefix}:{ident:>{width}} sim={score:.3f}")