File size: 7,177 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""External knowledge: pre-computed fused embeddings + PrimeKG hetero graph.

Loads the 27,686 BioLORD+graph-fused embeddings (3072-d) for diseases,
phenotypes, and genes. These were produced by raras-app's GNN+BioLORD
fusion pipeline on PrimeKG and are dropped into Gemeo to enrich:

  - cohort matching: nearest-neighbour over fused embeddings instead
    of just BioLORD text similarity
  - subgraph extraction: anchor on PrimeKG hetero relations instead of
    sparse Aura graph
  - reverse phenotyping: similar-disease lookup with semantic+graph signal

This is a drop-in upgrade with NO retraining required.
"""
from __future__ import annotations
import json
import logging
import os
from functools import lru_cache
from typing import Optional

import numpy as np

logger = logging.getLogger("gemeo.external_kg")

# Defaults to the bundled fp16 artifacts in this repo. Override with
# GEMEO_GRAPH_ML_DIR to use a custom location (e.g. dev workstation
# with the full fp64 export from raras-app).
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
DEFAULT_GRAPH_ML = os.environ.get(
    "GEMEO_GRAPH_ML_DIR",
    _REPO_DATA if os.path.exists(_REPO_DATA) else "/Users/dimas/raras-app/data/graph-ml",
)

# Archive names, in order of preference: the shipped fp16 variant first,
# then the full-precision dev export.
_FUSED_FNAME_FP16 = "fused_embeddings_fp16.npz"
_FUSED_FNAME_FULL = "fused_embeddings.npz"


def _fused_fname_for(directory: str) -> str:
    """Return the fused-embedding archive name to use inside *directory*.

    The fp16 archive wins when it exists in *directory*; otherwise the
    full-precision name is returned (whether or not that file exists).
    """
    if os.path.exists(os.path.join(directory, _FUSED_FNAME_FP16)):
        return _FUSED_FNAME_FP16
    return _FUSED_FNAME_FULL


# Kept for backward compatibility with code that reads this constant.
# The loader itself re-resolves the name per directory, because this value
# only reflects what DEFAULT_GRAPH_ML contained at import time.
_FUSED_FNAME = _fused_fname_for(DEFAULT_GRAPH_ML)


@lru_cache(maxsize=1)
def load_fused_embeddings(graph_ml_dir: Optional[str] = None) -> dict:
    """Load pre-computed fused embeddings + node id maps.

    Args:
        graph_ml_dir: Directory holding the embedding archive and
            ``node_ids.json``. Falls back to ``DEFAULT_GRAPH_ML`` when
            omitted or empty.

    Returns:
        An empty dict when the directory, the archive or ``node_ids.json``
        is missing (a warning is logged). Otherwise a dict that, for every
        kind present in BOTH files, contains the keys below (shapes are
        those of the bundled export):
          - disease_emb: np.ndarray (10468, 3072)
          - phenotype_emb: np.ndarray (11652, 3072)
          - gene_emb: np.ndarray (5566, 3072)
          - disease_idx2id: {pos: orpha_code_str}
          - phenotype_idx2id: {pos: hpo_id_str}
          - gene_idx2id: {pos: hgnc_id_str}
          - disease_id2idx, phenotype_id2idx, gene_id2idx: inverse maps

    The result is memoised by ``lru_cache(maxsize=1)`` keyed on the call
    arguments, so every caller receives the same dict object; treat it as
    read-only.
    """
    d = graph_ml_dir or DEFAULT_GRAPH_ML
    if not os.path.exists(d):
        logger.warning(f"graph-ml dir not found: {d}")
        return {}

    # Resolve the archive name against the directory actually being loaded,
    # not against whatever DEFAULT_GRAPH_ML held at import time.
    fused_path = os.path.join(d, _fused_fname_for(d))
    nid_path = os.path.join(d, "node_ids.json")

    if not (os.path.exists(fused_path) and os.path.exists(nid_path)):
        logger.warning(f"missing fused embeddings or node_ids in {d}")
        return {}

    with open(nid_path, encoding="utf-8") as f:
        nids = json.load(f)

    out = {}
    # NpzFile keeps the archive open until closed; indexing it returns the
    # array itself, so it is safe to close the archive once all kinds are read.
    with np.load(fused_path) as fz:
        for kind in ("disease", "phenotype", "gene"):
            if kind in fz.files and kind in nids:
                emb = fz[kind]
                # JSON object keys are strings; positions are stored as ints.
                idx2id = {int(pos): node_id for pos, node_id in nids[kind].items()}
                id2idx = {node_id: int(pos) for pos, node_id in nids[kind].items()}
                out[f"{kind}_emb"] = emb
                out[f"{kind}_idx2id"] = idx2id
                out[f"{kind}_id2idx"] = id2idx
                logger.info(f"  loaded {kind}: {emb.shape} embeddings, {len(id2idx)} ids")

    return out


def disease_neighbors(orpha_code: str, k: int = 10,
                      graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Return the k nearest diseases to the given orpha code by fused embedding.

    Args:
        orpha_code: Orphanet code of the query disease; converted with
            ``str()`` before lookup.
        k: Maximum number of neighbours to return. Non-positive values
            yield an empty list.
        graph_ml_dir: Optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(orpha_code, cosine_similarity)`` pairs ordered from most to least
        similar, never including the query disease itself. Empty when the
        embeddings are unavailable or the code is unknown.
    """
    kg = load_fused_embeddings(graph_ml_dir)
    if k <= 0 or not kg or "disease_emb" not in kg:
        return []
    emb = kg["disease_emb"]
    id2idx = kg["disease_id2idx"]
    idx2id = kg["disease_idx2id"]
    key = str(orpha_code)
    if key not in id2idx:
        return []
    qi = id2idx[key]
    qv = emb[qi]
    # Cosine similarity; the epsilon keeps all-zero rows from dividing by zero.
    norms = np.linalg.norm(emb, axis=1) + 1e-9
    qn = np.linalg.norm(qv) + 1e-9
    sims = (emb @ qv) / (norms * qn)
    order = np.argsort(-sims)
    # Drop the query by index rather than by rank: with tied similarities
    # (duplicate or all-zero embeddings) it is not guaranteed to sort first.
    top = order[order != qi][:k]
    return [(idx2id[int(i)], float(sims[i])) for i in top]


def phenotype_for_disease(orpha_code: str, k: int = 20,
                          graph_ml_dir: str = None) -> list[tuple[str, float]]:
    """Rank phenotypes by cosine similarity to a disease in the fused space.

    NOTE: the ranking is cross-modal similarity between a disease vector and
    the phenotype vectors; it only approximates real disease→phenotype edges
    from PrimeKG. Use Orphanet/HPO directly when ground-truth disease→HPO
    links are required.

    Returns up to ``k`` ``(phenotype_id, similarity)`` pairs, best first, or
    an empty list when the embeddings are unavailable or the code is unknown.
    """
    store = load_fused_embeddings(graph_ml_dir)
    if not store:
        return []
    if "disease_emb" not in store or "phenotype_emb" not in store:
        return []

    row = store["disease_id2idx"].get(str(orpha_code))
    if row is None:
        return []

    query = store["disease_emb"][row]
    phenotypes = store["phenotype_emb"]

    # Epsilon guards against division by zero for all-zero vectors.
    row_norms = np.linalg.norm(phenotypes, axis=1) + 1e-9
    query_norm = np.linalg.norm(query) + 1e-9
    scores = (phenotypes @ query) / (row_norms * query_norm)

    labels = store["phenotype_idx2id"]
    best = np.argsort(-scores)[:k]
    return [(labels[int(j)], float(scores[j])) for j in best]


def gene_for_disease(orpha_code: str, k: int = 10,
                     graph_ml_dir: str = None) -> list[tuple[str, float]]:
    """Rank genes by cosine similarity to a disease in the fused space.

    Returns up to ``k`` ``(gene_id, similarity)`` pairs, best first, or an
    empty list when the embeddings are unavailable or the code is unknown.
    """
    tables = load_fused_embeddings(graph_ml_dir)
    have_both = bool(tables) and "disease_emb" in tables and "gene_emb" in tables
    if not have_both:
        return []

    code = str(orpha_code)
    disease_index = tables["disease_id2idx"]
    if code not in disease_index:
        return []

    disease_vec = tables["disease_emb"][disease_index[code]]
    gene_matrix = tables["gene_emb"]

    # Cosine similarity of every gene row against the disease vector;
    # the epsilon keeps zero-length vectors from dividing by zero.
    gene_norms = np.linalg.norm(gene_matrix, axis=1) + 1e-9
    disease_norm = np.linalg.norm(disease_vec) + 1e-9
    similarity = (gene_matrix @ disease_vec) / (gene_norms * disease_norm)

    gene_ids = tables["gene_idx2id"]
    result = []
    for pos in np.argsort(-similarity)[:k]:
        result.append((gene_ids[int(pos)], float(similarity[pos])))
    return result


def patient_disease_match(patient_emb: np.ndarray, k: int = 10,
                          graph_ml_dir: Optional[str] = None) -> list[tuple[str, float]]:
    """Given a patient embedding, return the k closest diseases.

    Useful when build_gemeo wants to re-rank diagnoses against the
    full PrimeKG-fused space, not just the local Aura graph.

    Args:
        patient_emb: A single embedding whose length equals the width of the
            disease embedding matrix (3072 per the module docs). A 1-D vector
            or a single-row array such as ``(1, dim)`` is accepted, as is
            any array-like that ``np.asarray`` can convert.
        k: Maximum number of diseases to return. Non-positive values yield
            an empty list.
        graph_ml_dir: Optional override passed to ``load_fused_embeddings``.

    Returns:
        ``(orpha_code, cosine_similarity)`` pairs, best first. Empty when the
        embeddings are unavailable or the input does not hold exactly one
        vector of the expected width (a warning is logged in that case).
    """
    kg = load_fused_embeddings(graph_ml_dir)
    if k <= 0 or not kg or "disease_emb" not in kg:
        return []
    de = kg["disease_emb"]
    query = np.asarray(patient_emb)
    dim = de.shape[1]
    # Require the last axis to match AND the input to hold a single vector;
    # otherwise the matrix product below would fail or silently broadcast.
    if query.ndim == 0 or query.shape[-1] != dim or query.size != dim:
        logger.warning(f"dim mismatch: patient {query.shape}, disease {de.shape}")
        return []
    query = query.reshape(-1)  # collapse (1, dim) style inputs to a flat vector
    norms = np.linalg.norm(de, axis=1) + 1e-9
    pn = np.linalg.norm(query) + 1e-9
    sims = (de @ query) / (norms * pn)
    top = np.argsort(-sims)[:k]
    idx2id = kg["disease_idx2id"]
    return [(idx2id[int(i)], float(sims[i])) for i in top]


if __name__ == "__main__":
    # Manual smoke test: load the embeddings, report their shapes, then print
    # the nearest diseases / phenotypes / genes for ORPHA:100.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s")
    print("=== sanity check: external_kg ===\n")

    tables = load_fused_embeddings()
    for table_name in ("disease_emb", "phenotype_emb", "gene_emb"):
        summary = tables[table_name].shape if table_name in tables else "MISSING"
        print(f"{table_name}: {summary}")
    print()

    # ATM neighbors (ORPHA:100)
    print("ATM (ORPHA:100) nearest diseases:")
    for orpha, score in disease_neighbors("100", k=8):
        print(f"  ORPHA:{orpha:>6}  sim={score:.3f}")
    print()

    print("ATM (ORPHA:100) similar phenotypes (top 10):")
    for hpo, score in phenotype_for_disease("100", k=10):
        print(f"  HP:{hpo:>10}  sim={score:.3f}")
    print()

    print("ATM (ORPHA:100) similar genes (top 5):")
    for gene, score in gene_for_disease("100", k=5):
        print(f"  HGNC:{gene:>10}  sim={score:.3f}")