# gemeo-twin-stack / src/gemeo/bridge.py
# timmers's picture
# GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
# 089d665 verified
"""Bridge to raras-app graph-ml artifacts.
raras-app trains PhenoGNet monthly and writes:
/Users/dimas/raras-app/data/graph-ml/biolord_embeddings.npz (768-dim, init)
/Users/dimas/raras-app/data/graph-ml/graph_embeddings.npz (64-dim, contrastive GNN)
/Users/dimas/raras-app/data/graph-ml/fused_embeddings.npz (3072-dim, final, Neo4j-indexed)
/Users/dimas/raras-app/data/graph-ml/node_ids.json (index → ORPHA/HPO/HGNC)
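/Users/dimas/raras-app/data/graph-ml/hetero_graph.json (edge adjacency by relation)
These are the dev-workstation paths; at runtime the directory is resolved from
the RARAS_APP_GRAPH_ML environment variable or the bundled gemeo/data copy.
Gemeo loads these read-only. We never retrain inside swarm-py — retrain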
runs in raras-app's `retrain-scheduled.sh` cron. Phase-2 GNN training
lives in `gemeo/train/` and writes its own artifacts under `gemeo/artifacts/`.
"""
from __future__ import annotations
import os
import json
import logging
from functools import lru_cache
from typing import Optional
logger = logging.getLogger("gemeo.bridge")

# Default location — env override allowed.
# Priority:
# 1. RARAS_APP_GRAPH_ML env (dev / custom location with full fp64 artifacts)
# 2. ./gemeo/data (the fp16 bundle shipped with the repo, ~54 MB)
# 3. /Users/dimas/raras-app/data/graph-ml (dev fallback for the original fp64)
# An empty RARAS_APP_GRAPH_ML is treated as unset; taking it literally would
# make every artifact path silently resolve relative to the current directory.
_REPO_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
_DEV_GRAPH_ML = "/Users/dimas/raras-app/data/graph-ml"
RARAS_APP_GRAPH_ML = os.environ.get("RARAS_APP_GRAPH_ML") or (
    _REPO_DATA if os.path.exists(_REPO_DATA) else _DEV_GRAPH_ML
)

# Pick the fp16 file if it exists (shipped variant), else the fp64 one (dev).
FUSED_FNAME = (
    "fused_embeddings_fp16.npz"
    if os.path.exists(os.path.join(RARAS_APP_GRAPH_ML, "fused_embeddings_fp16.npz"))
    else "fused_embeddings.npz"
)
logger.info("gemeo.bridge: data dir=%s fused=%s", RARAS_APP_GRAPH_ML, FUSED_FNAME)

# In-process artifact cache, keyed by "fused" / "graph". Values are dicts of
# arrays already read out of the .npz archive (fully in memory, not mmapped),
# so they stay valid after the archive is closed. Caching avoids re-reading a
# potentially very large file (fp64 fused is ~649 MB) on every lookup.
_ARTIFACT_CACHE: dict = {}
def _path(name: str) -> str:
    """Resolve an artifact file name against the configured graph-ml directory."""
    base_dir = RARAS_APP_GRAPH_ML
    return os.path.join(base_dir, name)
def has_raras_artifacts() -> bool:
    """Return True when both the fused-embedding archive and node_ids.json exist.

    Checks the files in order and stops at the first missing one.
    """
    required = (FUSED_FNAME, "node_ids.json")
    for fname in required:
        if not os.path.exists(_path(fname)):
            return False
    return True
@lru_cache(maxsize=1)
def load_node_ids() -> dict:
    """Return the node-id lists exported by raras-app.

    Shape: {'disease': [orpha,...], 'phenotype': [hpo,...], 'gene': [symbol,...]}.
    Position i in each list is the row index used for that id in the
    matching embedding matrix (see load_node_index).

    If node_ids.json is missing, unreadable, or not valid JSON, a message is
    logged and an empty mapping with the three keys is returned instead of
    raising. This matches the best-effort behaviour of the other loaders.

    The result (including the empty fallback) is memoized by lru_cache, so a
    file that appears later is only picked up after
    ``load_node_ids.cache_clear()``. Callers share the returned object and
    must not mutate it.
    """
    p = _path("node_ids.json")
    if not os.path.exists(p):
        logger.warning("node_ids.json not found at %s", p)
        return {"disease": [], "phenotype": [], "gene": []}
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error("Failed to load node_ids.json: %s", e)
        return {"disease": [], "phenotype": [], "gene": []}
@lru_cache(maxsize=1)
def load_node_index() -> dict:
    """Invert load_node_ids() into per-type id -> row-index maps.

    Shape: {'disease': {orpha: idx}, 'phenotype': {hpo: idx}, 'gene': {symbol: idx}}.
    If an id occurs more than once in a list, its last position wins.
    Memoized; callers share the returned dict and must not mutate it.
    """
    inverted = {}
    for node_type, id_list in load_node_ids().items():
        inverted[node_type] = {node_id: pos for pos, node_id in enumerate(id_list)}
    return inverted
def load_fused_embeddings():
    """Return the fused node embeddings, loading and caching them on first use.

    Returns a dict mapping every array name stored in the archive to its
    array. The expected keys are 'disease', 'phenotype' and 'gene', with
    3072-dim rows. Returns None if the file is missing or fails to load.
    Failures are not cached, so a later call retries.

    The file is chosen via FUSED_FNAME, which auto-detects one of two variants:
    - the fp16-quantized `fused_embeddings_fp16.npz` (~41 MB), which is what
      ships in `gemeo/data/`;
    - the original fp64 `fused_embeddings.npz` (~649 MB), which only exists
      in the dev workstation export from raras-app.
    """
    if "fused" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["fused"]
    p = _path(FUSED_FNAME)
    if not os.path.exists(p):
        logger.warning("%s not found at %s", FUSED_FNAME, p)
        return None
    try:
        import numpy as np

        # NpzFile keeps the archive open. npz[k] reads each array fully into
        # memory, so the archive can be closed once everything is extracted.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        logger.error("Failed to load fused embeddings: %s", e)
        return None
    _ARTIFACT_CACHE["fused"] = out
    return out
def load_graph_embeddings():
    """64-dim PhenoGNet embeddings (lighter, for fast in-memory ops).

    Returns a dict mapping every array name in `graph_embeddings.npz` to its
    array, cached after the first successful load. Returns None, without
    logging, if the file is absent. Returns None and logs an error if it
    fails to load; failures are not cached.
    """
    if "graph" in _ARTIFACT_CACHE:
        return _ARTIFACT_CACHE["graph"]
    p = _path("graph_embeddings.npz")
    if not os.path.exists(p):
        return None
    try:
        import numpy as np

        # Close the archive's file handle once the arrays are materialized.
        with np.load(p) as npz:
            out = {k: npz[k] for k in npz.files}
    except Exception as e:
        logger.error("Failed to load graph embeddings: %s", e)
        return None
    _ARTIFACT_CACHE["graph"] = out
    return out
@lru_cache(maxsize=1)
def load_hetero_graph() -> Optional[dict]:
    """The exported heterogeneous graph — node counts and edges by relation.

    Returns the parsed hetero_graph.json, or None if the file is missing or
    fails to load (the failure is logged). The result, including None, is
    memoized by lru_cache. Callers share the returned object and must not
    mutate it.
    """
    p = _path("hetero_graph.json")
    if not os.path.exists(p):
        return None
    try:
        # JSON is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(p, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        # Broad on purpose: best-effort loader (a large graph may also hit MemoryError).
        logger.error("Failed to load hetero_graph.json: %s", e)
        return None
def _lookup_embedding(node_type: str, node_id: str, kind: str):
    """Shared implementation of the lookup_*_embedding helpers.

    ``kind == "fused"`` selects the fused embeddings; any other value selects
    the graph embeddings. Returns the row for ``node_id`` in the
    ``node_type`` matrix, or None if any of the following holds:
    - the embeddings are unavailable;
    - the id is unknown;
    - the matrix is missing from the archive;
    - the id's index is past the end of the matrix (e.g. node_ids.json and
      the archive are out of sync).
    The row is a view into the cached matrix, so do not modify it in place.
    """
    emb = load_fused_embeddings() if kind == "fused" else load_graph_embeddings()
    if emb is None:
        return None
    matrix = emb.get(node_type)
    if matrix is None:
        return None
    idx = load_node_index().get(node_type, {}).get(node_id)
    if idx is None or idx >= len(matrix):
        return None
    return matrix[idx]


def lookup_disease_embedding(orpha: str, kind: str = "fused"):
    """Embedding row for an ORPHA disease id, or None if unavailable.

    ``kind="fused"`` gives the fused embeddings; any other value gives the
    64-dim graph embeddings.
    """
    return _lookup_embedding("disease", orpha, kind)


def lookup_phenotype_embedding(hpo: str, kind: str = "fused"):
    """Embedding row for an HPO phenotype id, or None if unavailable.

    ``kind="fused"`` gives the fused embeddings; any other value gives the
    64-dim graph embeddings.
    """
    return _lookup_embedding("phenotype", hpo, kind)


def lookup_gene_embedding(symbol: str, kind: str = "fused"):
    """Embedding row for a gene symbol, or None if unavailable.

    ``kind="fused"`` gives the fused embeddings; any other value gives the
    64-dim graph embeddings.
    """
    return _lookup_embedding("gene", symbol, kind)
def stats() -> dict:
    """Collect diagnostic info for /api/gemeo/health.

    "fused_loaded" / "graph_loaded" report the cache state at call time,
    before this function itself calls load_fused_embeddings(). That call
    means the first health check loads the fused archive. "fused_dim" is only
    present when the fused embeddings loaded, and is 0 if they contain no
    'disease' matrix.
    """
    report: dict = {}
    report["graph_ml_dir"] = RARAS_APP_GRAPH_ML
    report["available"] = has_raras_artifacts()
    for report_key, cache_key in (("fused_loaded", "fused"), ("graph_loaded", "graph")):
        report[report_key] = cache_key in _ARTIFACT_CACHE

    node_ids = load_node_ids()
    for count_key, node_type in (
        ("n_diseases", "disease"),
        ("n_phenotypes", "phenotype"),
        ("n_genes", "gene"),
    ):
        report[count_key] = len(node_ids.get(node_type, []))

    fused = load_fused_embeddings()
    if fused is not None:
        if "disease" in fused:
            report["fused_dim"] = int(fused["disease"].shape[1])
        else:
            report["fused_dim"] = 0
    return report