File size: 6,192 Bytes
"""Bridge to raras-app graph-ml artifacts.

raras-app trains PhenoGNet monthly and writes:
    /Users/dimas/raras-app/data/graph-ml/biolord_embeddings.npz (768-dim, init)
    /Users/dimas/raras-app/data/graph-ml/graph_embeddings.npz (64-dim, contrastive GNN)
    /Users/dimas/raras-app/data/graph-ml/fused_embeddings.npz (3072-dim, final, Neo4j-indexed)
    /Users/dimas/raras-app/data/graph-ml/node_ids.json (index → ORPHA/HPO/HGNC)
    /Users/dimas/raras-app/data/graph-ml/hetero_graph.json (edge adjacency)

Gemeo loads these read-only. We never retrain inside swarm-py; retraining
runs in raras-app's `retrain-scheduled.sh` cron. Phase-2 GNN training
lives in `gemeo/train/` and writes its own artifacts under `gemeo/artifacts/`.
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache
from typing import Optional
logger = logging.getLogger("gemeo.bridge")

# Artifact directory resolution, highest priority first:
#   1. RARAS_APP_GRAPH_ML env var (dev / custom location with full fp64 artifacts)
#   2. the `data` directory next to this module (the fp16 bundle shipped with
#      the repo, ~54 MB)
#   3. /Users/dimas/raras-app/data/graph-ml (dev fallback for the original fp64)
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")

RARAS_APP_GRAPH_ML = os.environ.get("RARAS_APP_GRAPH_ML")
if RARAS_APP_GRAPH_ML is None:
    # No override: use the shipped bundle when present, else the dev export.
    if os.path.exists(_REPO_DATA):
        RARAS_APP_GRAPH_ML = _REPO_DATA
    else:
        RARAS_APP_GRAPH_ML = "/Users/dimas/raras-app/data/graph-ml"

# The fp16 archive (shipped variant) wins when it exists in the chosen
# directory; otherwise the fp64 file name (dev) is used.
FUSED_FNAME = "fused_embeddings.npz"
if os.path.exists(os.path.join(RARAS_APP_GRAPH_ML, "fused_embeddings_fp16.npz")):
    FUSED_FNAME = "fused_embeddings_fp16.npz"

logger.info(f"gemeo.bridge: data dir={RARAS_APP_GRAPH_ML} fused={FUSED_FNAME}")

# In-process cache of loaded embedding dicts, keyed by artifact kind
# ("fused", "graph"), so each archive is read from disk at most once.
_ARTIFACT_CACHE: dict = {}
def _path(name: str) -> str:
    """Return the location of artifact file `name` inside RARAS_APP_GRAPH_ML."""
    artifact_dir = RARAS_APP_GRAPH_ML
    return os.path.join(artifact_dir, name)
def has_raras_artifacts() -> bool:
    """Report whether both the fused embedding archive and node_ids.json exist on disk."""
    for required in (FUSED_FNAME, "node_ids.json"):
        if not os.path.exists(_path(required)):
            return False
    return True
@lru_cache(maxsize=1)
def load_node_ids() -> dict:
    """Load node_ids.json from the artifact directory.

    Returns:
        The parsed JSON, documented as
        {'disease': [orpha,...], 'phenotype': [hpo,...], 'gene': [symbol,...]}.
        When the file is missing, a warning is logged and the same three keys
        are returned with empty lists.

    The result (including the empty fallback) is memoized by `lru_cache`, so
    every caller receives the same dict object.
    """
    p = _path("node_ids.json")
    if not os.path.exists(p):
        logger.warning(f"node_ids.json not found at {p}")
        return {"disease": [], "phenotype": [], "gene": []}
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # default, which can mis-decode non-ASCII identifiers.
    with open(p, encoding="utf-8") as f:
        return json.load(f)
@lru_cache(maxsize=1)
def load_node_index() -> dict:
    """Invert load_node_ids(): {'disease': {orpha: idx}, 'phenotype': {hpo: idx}, 'gene': {symbol: idx}}.

    Each identifier maps to its position in the corresponding list; if an
    identifier repeats, the last position wins. The result is memoized.
    """
    inverted = {}
    for kind, members in load_node_ids().items():
        positions = {}
        for position, node_id in enumerate(members):
            positions[node_id] = position
        inverted[kind] = positions
    return inverted
def load_fused_embeddings():
    """Load the fused embedding archive and return its arrays as a dict.

    The archive read is `FUSED_FNAME` inside the artifact directory: the
    fp16-quantized variant (`fused_embeddings_fp16.npz`, ~41 MB, shipped in
    `gemeo/data/`) or the original fp64 (`fused_embeddings.npz`, ~649 MB,
    dev workstation export from raras-app).

    Returns:
        A dict mapping every array name in the archive to its np.ndarray
        (expected: 'disease', 'phenotype', 'gene', 3072-dim), or None when
        the file is missing or cannot be read. A successful load is stored
        in `_ARTIFACT_CACHE["fused"]` and reused on later calls; failures
        are not cached, so a later call retries.
    """
    if "fused" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["fused"]
    p = _path(FUSED_FNAME)
    if not os.path.exists(p):
        logger.warning(f"{FUSED_FNAME} not found at {p}")
        return None
    try:
        import numpy as np
        # np.load on an .npz returns an NpzFile that keeps the archive open
        # and decodes members on access. Read every array while it is open,
        # then let the context manager close the file handle.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        # Best-effort loader: report the problem and let callers see None.
        logger.error(f"Failed to load fused embeddings: {e}")
        return None
    _ARTIFACT_CACHE["fused"] = out
    return out
def load_graph_embeddings():
    """Load graph_embeddings.npz (64-dim PhenoGNet embeddings, lighter, for fast in-memory ops).

    Returns:
        A dict mapping every array name in the archive to its np.ndarray, or
        None when the file is missing (silently) or cannot be read (logged
        as an error). A successful load is stored in
        `_ARTIFACT_CACHE["graph"]` and reused on later calls; failures are
        not cached, so a later call retries.
    """
    if "graph" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["graph"]
    p = _path("graph_embeddings.npz")
    if not os.path.exists(p):
        return None
    try:
        import numpy as np
        # np.load on an .npz returns an NpzFile that keeps the archive open
        # and decodes members on access. Read every array while it is open,
        # then let the context manager close the file handle.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        # Best-effort loader: report the problem and let callers see None.
        logger.error(f"Failed to load graph embeddings: {e}")
        return None
    _ARTIFACT_CACHE["graph"] = out
    return out
@lru_cache(maxsize=1)
def load_hetero_graph() -> Optional[dict]:
    """Load hetero_graph.json, the exported heterogeneous graph (node counts and edges by relation).

    Returns:
        The parsed JSON, or None when the file is missing or fails to
        load (the failure is logged as an error).

    The result, including a None outcome, is memoized by `lru_cache`.
    """
    p = _path("hetero_graph.json")
    if not os.path.exists(p):
        return None
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale default.
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        # Best-effort loader: report the problem and let callers see None.
        logger.error(f"Failed to load hetero_graph.json: {e}")
        return None
def lookup_disease_embedding(orpha: str, kind: str = "fused"):
    """Return the embedding row for disease `orpha`, or None if unavailable.

    `kind == "fused"` reads the fused embeddings; any other value reads the
    graph embeddings. None is returned when the embeddings could not be
    loaded or `orpha` is not in the disease index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    row = load_node_index().get("disease", {}).get(orpha)
    return None if row is None else table["disease"][row]
def lookup_phenotype_embedding(hpo: str, kind: str = "fused"):
    """Return the embedding row for phenotype `hpo`, or None if unavailable.

    `kind == "fused"` reads the fused embeddings; any other value reads the
    graph embeddings. None is returned when the embeddings could not be
    loaded or `hpo` is not in the phenotype index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    row = load_node_index().get("phenotype", {}).get(hpo)
    return None if row is None else table["phenotype"][row]
def lookup_gene_embedding(symbol: str, kind: str = "fused"):
    """Return the embedding row for gene `symbol`, or None if unavailable.

    `kind == "fused"` reads the fused embeddings; any other value reads the
    graph embeddings. None is returned when the embeddings could not be
    loaded or `symbol` is not in the gene index.
    """
    if kind == "fused":
        table = load_fused_embeddings()
    else:
        table = load_graph_embeddings()
    if table is None:
        return None
    row = load_node_index().get("gene", {}).get(symbol)
    return None if row is None else table["gene"][row]
def stats() -> dict:
    """Collect diagnostic info for /api/gemeo/health.

    Reports the artifact directory, whether the required files exist, which
    embedding sets were already cached when the call started, the node
    counts per kind, and (when the fused embeddings load) their second-axis
    size as `fused_dim`, or 0 if they contain no 'disease' array.
    """
    report = {
        "graph_ml_dir": RARAS_APP_GRAPH_ML,
        "available": has_raras_artifacts(),
        # Snapshot taken before this call triggers any load below.
        "fused_loaded": "fused" in _ARTIFACT_CACHE,
        "graph_loaded": "graph" in _ARTIFACT_CACHE,
    }

    node_ids = load_node_ids()
    for field, kind in (
        ("n_diseases", "disease"),
        ("n_phenotypes", "phenotype"),
        ("n_genes", "gene"),
    ):
        report[field] = len(node_ids.get(kind, []))

    embeddings = load_fused_embeddings()
    if embeddings is None:
        return report
    if "disease" in embeddings:
        report["fused_dim"] = int(embeddings["disease"].shape[1])
    else:
        report["fused_dim"] = 0
    return report
|