File size: 6,192 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Bridge to raras-app graph-ml artifacts.

raras-app trains PhenoGNet monthly and writes:
    /Users/dimas/raras-app/data/graph-ml/biolord_embeddings.npz   (768-dim, init)
    /Users/dimas/raras-app/data/graph-ml/graph_embeddings.npz     (64-dim, contrastive GNN)
    /Users/dimas/raras-app/data/graph-ml/fused_embeddings.npz     (3072-dim, final, Neo4j-indexed)
    /Users/dimas/raras-app/data/graph-ml/node_ids.json            (index → ORPHA/HPO/HGNC)
    /Users/dimas/raras-app/data/graph-ml/hetero_graph.json        (edges adjacency)

Gemeo loads these read-only. We never retrain inside swarm-py — retrain
runs in raras-app's `retrain-scheduled.sh` cron. Phase-2 GNN training
lives in `gemeo/train/` and writes its own artifacts under `gemeo/artifacts/`.
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache
from typing import Optional

logger = logging.getLogger("gemeo.bridge")

# Artifact directory resolution, first match wins:
#   1. the RARAS_APP_GRAPH_ML environment variable (dev / custom location
#      holding the full fp64 artifacts)
#   2. the `data` directory next to this module (the fp16 bundle shipped
#      with the repo)
#   3. /Users/dimas/raras-app/data/graph-ml (dev fallback, original fp64)
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
if "RARAS_APP_GRAPH_ML" in os.environ:
    RARAS_APP_GRAPH_ML = os.environ["RARAS_APP_GRAPH_ML"]
elif os.path.exists(_REPO_DATA):
    RARAS_APP_GRAPH_ML = _REPO_DATA
else:
    RARAS_APP_GRAPH_ML = "/Users/dimas/raras-app/data/graph-ml"

# Fused-embeddings file name: the fp16 (shipped) variant wins when it is
# present in the resolved directory, otherwise the fp64 (dev) file is used.
if os.path.exists(os.path.join(RARAS_APP_GRAPH_ML, "fused_embeddings_fp16.npz")):
    FUSED_FNAME = "fused_embeddings_fp16.npz"
else:
    FUSED_FNAME = "fused_embeddings.npz"
logger.info(f"gemeo.bridge: data dir={RARAS_APP_GRAPH_ML} fused={FUSED_FNAME}")

# Process-wide cache of loaded embedding dicts; the loaders in this module
# store their results here under the keys "fused" and "graph".
_ARTIFACT_CACHE: dict = {}


def _path(name: str) -> str:
    """Return *name* joined onto the artifact directory ``RARAS_APP_GRAPH_ML``."""
    base_dir = RARAS_APP_GRAPH_ML
    return os.path.join(base_dir, name)


def has_raras_artifacts() -> bool:
    """Report whether the required artifacts exist on disk.

    The required files are the selected fused-embeddings file
    (``FUSED_FNAME``) and ``node_ids.json``, both resolved via ``_path``.
    Returns ``True`` only when both paths exist.
    """
    for required in (FUSED_FNAME, "node_ids.json"):
        if not os.path.exists(_path(required)):
            return False
    return True


@lru_cache(maxsize=1)
def load_node_ids() -> dict:
    """Load the node-id lists from ``node_ids.json``.

    Returns:
        The parsed JSON content, expected to have the shape
        ``{'disease': [orpha, ...], 'phenotype': [hpo, ...], 'gene': [symbol, ...]}``.
        When the file does not exist, a warning is logged and a dict with
        those three keys mapped to empty lists is returned instead.

    Raises:
        OSError, ValueError: read or JSON-decoding failures are not caught
        here and propagate to the caller.

    The result (including the empty fallback) is memoized by ``lru_cache``,
    so all callers share one dict object; treat it as read-only.
    """
    p = _path("node_ids.json")
    if not os.path.exists(p):
        logger.warning(f"node_ids.json not found at {p}")
        return {"disease": [], "phenotype": [], "gene": []}
    # JSON text is UTF-8; be explicit rather than relying on the platform's
    # locale-dependent default encoding.
    with open(p, encoding="utf-8") as f:
        return json.load(f)


@lru_cache(maxsize=1)
def load_node_index() -> dict:
    """Build the inverse of ``load_node_ids()``.

    For every node kind, maps each node id to its position in that kind's
    list, e.g. ``{'disease': {orpha: idx}, 'phenotype': {hpo: idx},
    'gene': {symbol: idx}}``. If an id occurs more than once, the last
    position wins. The result is memoized by ``lru_cache``.
    """
    index = {}
    for kind, members in load_node_ids().items():
        positions = {}
        for pos, node_id in enumerate(members):
            positions[node_id] = pos
        index[kind] = positions
    return index


def load_fused_embeddings():
    """Load the fused embeddings, memoized in ``_ARTIFACT_CACHE["fused"]``.

    Returns a dict mapping each array name stored in the archive to its
    ``np.ndarray`` — expected to be
    ``{'disease': ..., 'phenotype': ..., 'gene': ...}`` (3072-dim).

    The file read is ``FUSED_FNAME``, which is chosen at import time:
    the fp16-quantized variant (`fused_embeddings_fp16.npz`, ~41 MB) when
    present, otherwise the original fp64 (`fused_embeddings.npz`, ~649 MB).
    The fp16 version is what ships in `gemeo/data/`; fp64 only exists
    in the dev workstation export from raras-app.

    Returns ``None`` (after logging) when the file is missing or cannot be
    loaded; failures are not cached, so a later call retries.
    """
    if "fused" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["fused"]
    p = _path(FUSED_FNAME)
    if not os.path.exists(p):
        logger.warning(f"{FUSED_FNAME} not found at {p}")
        return None
    try:
        import numpy as np
        # np.load on an .npz returns an NpzFile that holds the archive open.
        # Indexing it reads each array fully into memory, so the archive can
        # be closed right after instead of leaking the file handle.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        logger.error(f"Failed to load fused embeddings: {e}")
        return None
    _ARTIFACT_CACHE["fused"] = out
    return out


def load_graph_embeddings():
    """Load ``graph_embeddings.npz``, memoized in ``_ARTIFACT_CACHE["graph"]``.

    These are the 64-dim PhenoGNet embeddings (lighter, for fast in-memory
    ops). Returns a dict mapping each array name stored in the archive to
    its ``np.ndarray``.

    Returns ``None`` when the file is missing (silently) or cannot be
    loaded (after logging an error); failures are not cached, so a later
    call retries.
    """
    if "graph" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["graph"]
    p = _path("graph_embeddings.npz")
    if not os.path.exists(p):
        return None
    try:
        import numpy as np
        # np.load on an .npz returns an NpzFile that holds the archive open.
        # Indexing it reads each array fully into memory, so the archive can
        # be closed right after instead of leaking the file handle.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        logger.error(f"Failed to load graph embeddings: {e}")
        return None
    _ARTIFACT_CACHE["graph"] = out
    return out


@lru_cache(maxsize=1)
def load_hetero_graph() -> Optional[dict]:
    """Load the exported heterogeneous graph from ``hetero_graph.json``.

    Returns:
        The parsed JSON content (node counts and edges by relation), or
        ``None`` when the file is missing or cannot be read/parsed. Load
        failures are logged as errors; a missing file is not logged.

    The result, including a ``None`` outcome, is memoized by ``lru_cache``,
    so the file is read at most once per process.
    """
    p = _path("hetero_graph.json")
    if not os.path.exists(p):
        return None
    try:
        # JSON text is UTF-8; be explicit rather than relying on the
        # platform's locale-dependent default encoding.
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Failed to load hetero_graph.json: {e}")
        return None


def lookup_disease_embedding(orpha: str, kind: str = "fused"):
    """Return the embedding row for disease id *orpha*, or ``None``.

    ``kind == "fused"`` reads from ``load_fused_embeddings()``; any other
    value reads from ``load_graph_embeddings()``. ``None`` is returned when
    the embeddings are unavailable or *orpha* is not in the disease index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    disease_index = load_node_index().get("disease", {})
    row = disease_index.get(orpha)
    return None if row is None else table["disease"][row]


def lookup_phenotype_embedding(hpo: str, kind: str = "fused"):
    """Return the embedding row for phenotype id *hpo*, or ``None``.

    ``kind == "fused"`` reads from ``load_fused_embeddings()``; any other
    value reads from ``load_graph_embeddings()``. ``None`` is returned when
    the embeddings are unavailable or *hpo* is not in the phenotype index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    phenotype_index = load_node_index().get("phenotype", {})
    row = phenotype_index.get(hpo)
    return None if row is None else table["phenotype"][row]


def lookup_gene_embedding(symbol: str, kind: str = "fused"):
    """Return the embedding row for gene *symbol*, or ``None``.

    ``kind == "fused"`` reads from ``load_fused_embeddings()``; any other
    value reads from ``load_graph_embeddings()``. ``None`` is returned when
    the embeddings are unavailable or *symbol* is not in the gene index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    gene_index = load_node_index().get("gene", {})
    row = gene_index.get(symbol)
    return None if row is None else table["gene"][row]


def stats() -> dict:
    """Collect diagnostic info for /api/gemeo/health.

    Reports the artifact directory, whether the required files exist,
    which embedding sets were already cached when the call started, and the
    number of disease/phenotype/gene ids. The call then loads the fused
    embeddings (if not yet cached) and, when that succeeds, adds
    ``fused_dim``: the second dimension of the ``disease`` array, or 0 if
    there is no ``disease`` entry.
    """
    report = {
        "graph_ml_dir": RARAS_APP_GRAPH_ML,
        "available": has_raras_artifacts(),
        # Snapshot of the cache state before this call triggers any load.
        "fused_loaded": "fused" in _ARTIFACT_CACHE,
        "graph_loaded": "graph" in _ARTIFACT_CACHE,
    }

    node_ids = load_node_ids()
    count_fields = (
        ("n_diseases", "disease"),
        ("n_phenotypes", "phenotype"),
        ("n_genes", "gene"),
    )
    for field, node_kind in count_fields:
        report[field] = len(node_ids.get(node_kind, []))

    fused = load_fused_embeddings()
    if fused is None:
        return report
    if "disease" in fused:
        report["fused_dim"] = int(fused["disease"].shape[1])
    else:
        report["fused_dim"] = 0
    return report