# FunGO / predictor.py
# Muteeba — FunGO v2.0 backend (commit 5c389ab)
# predictor.py
"""
FunGO — Prediction Engine
===========================
Loads XGBoost models once at startup.
Runs inference across all 3 ontologies (MFO, BPO, CCO).
Changes from original:
1. Added get_model_stats() — returns classifier counts per ontology
(used by /model/info endpoint).
2. Fixed open() to use context managers (file handles now closed).
3. tempfile.mktemp() replaced with NamedTemporaryFile (WSL fix).
4. Failed classifiers are counted and logged instead of silent pass.
5. Input shape validation in predict().
"""
import json
import logging
import pickle
import shutil
import subprocess
import tempfile
import numpy as np
from pathlib import Path
from config import PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META
logger = logging.getLogger(__name__)

# Ontologies scored by predict(), in this order.
ONTS = ["MFO", "BPO", "CCO"]

# ── Globals ───────────────────────────────────────────────────
# Filled in by load_all(); None means "load_all() has not completed".
_models_dict: "dict | None" = None      # ontology -> {GO term -> classifier}
_thresholds_dict: "dict | None" = None  # ontology -> {GO term -> decision threshold}
_ia_weights: "dict | None" = None       # unpickled from IA_PKL
_vocabularies: "object | None" = None   # unpickled from VOCAB_PKL
_top50_taxa: "list | None" = None       # int taxon IDs from FEAT_META
# ── Helpers ───────────────────────────────────────────────────
def _wsl_copy(src: Path) -> Path:
"""Copy file to temp path (WSL mounted-drive permission workaround)."""
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
tmp_path = Path(tmp.name)
shutil.copy2(str(src), str(tmp_path))
return tmp_path
def _safe_load(path: Path) -> object:
"""Load pickle with WSL permission workaround if needed."""
try:
subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
except Exception:
pass
try:
with open(path, "rb") as fh:
return pickle.load(fh)
except PermissionError:
pass
tmp_path = None
try:
tmp_path = _wsl_copy(path)
with open(tmp_path, "rb") as fh:
return pickle.load(fh)
finally:
if tmp_path and tmp_path.exists():
tmp_path.unlink()
def _safe_read_json(path: Path) -> dict:
"""Read JSON with WSL permission workaround."""
try:
subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
except Exception:
pass
for mode in ("r", "rb"):
try:
with open(path, mode) as fh:
raw = fh.read()
if isinstance(raw, bytes):
raw = raw.decode("utf-8", errors="replace")
return json.loads(raw)
except PermissionError:
continue
result = subprocess.run(["cat", str(path)], capture_output=True, text=True, check=True)
return json.loads(result.stdout)
# ── Public API ────────────────────────────────────────────────
def _unpack_model_bundle(raw, ont: str) -> tuple:
    """
    Normalise one ``models_<ONT>.pkl`` payload into ``({term: clf}, {term: threshold})``.

    Two layouts are accepted:
      * legacy  — a mapping whose first key starts with "GO:", mapping GO terms
        to classifiers; every threshold defaults to 0.5;
      * bundled — a mapping with "models" and "selected_terms" (parallel
        sequences) and an optional "thresholds" entry that is either a parallel
        sequence or a {term: threshold} mapping (0.5 for terms it lacks).

    Raises ValueError for an empty payload (previously a bare StopIteration).
    """
    if not raw:
        raise ValueError(f"[predictor] {ont}: model file is empty.")
    first_key = next(iter(raw))
    if isinstance(first_key, str) and first_key.startswith("GO:"):
        return raw, {t: 0.5 for t in raw}

    clf_list = raw["models"]
    term_list = raw["selected_terms"]
    if len(term_list) != len(clf_list):
        # zip() below stops at the shorter sequence; make the loss visible.
        logger.warning(
            "[predictor] %s: %d selected_terms vs %d models — surplus entries ignored",
            ont, len(term_list), len(clf_list),
        )
    thr_raw = raw.get("thresholds", [0.5] * len(clf_list))
    if isinstance(thr_raw, dict):
        thr_list = [thr_raw.get(t, 0.5) for t in term_list]
    else:
        thr_list = list(thr_raw)
        if len(thr_list) != len(term_list):
            logger.warning(
                "[predictor] %s: %d thresholds for %d terms — terms without one use the 0.5 default",
                ont, len(thr_list), len(term_list),
            )
    return dict(zip(term_list, clf_list)), dict(zip(term_list, thr_list))


def load_all():
    """
    Load all models and supporting data into memory.

    Call once at Flask startup (~30–120 s depending on hardware).

    Everything is loaded into locals first and written to the module globals
    only after every file has loaded.  If loading fails part-way, the previous
    state (``None`` on the first call) is left untouched, so predict() and the
    getters keep raising their "not loaded" RuntimeError instead of running
    against a half-populated model dict.

    Raises
    ------
    Whatever the underlying loaders raise (missing file, bad pickle/JSON,
    unexpected model-file layout), and ValueError for an empty model file.
    """
    global _models_dict, _thresholds_dict, _ia_weights, _vocabularies, _top50_taxa
    logger.info("[predictor] Loading vocabularies …")
    vocabularies = _safe_load(VOCAB_PKL)
    logger.info("[predictor] Loading IA weights …")
    ia_weights = _safe_load(IA_PKL)
    logger.info("[predictor] IA weights: %d terms", len(ia_weights))
    meta = _safe_read_json(FEAT_META)
    top50_taxa = [int(t) for t in meta["taxonomy_info"]["top50_taxa"]]
    logger.info("[predictor] Top-50 taxa loaded (%d)", len(top50_taxa))

    models_dict = {}
    thresholds_dict = {}
    for ont in ONTS:
        pkl_path = PKL_DIR / f"models_{ont}.pkl"
        size_mb = pkl_path.stat().st_size / 1e6
        logger.info("[predictor] Loading %s (%.0f MB) …", pkl_path.name, size_mb)
        models_d, thresholds_d = _unpack_model_bundle(_safe_load(pkl_path), ont)
        models_dict[ont] = models_d
        thresholds_dict[ont] = thresholds_d
        logger.info("[predictor] %s: %d classifiers ready", ont, len(models_d))

    # Publish only now that everything loaded.  _models_dict goes last because
    # predict() and get_model_stats() use it as the "loaded" marker.
    _vocabularies = vocabularies
    _ia_weights = ia_weights
    _top50_taxa = top50_taxa
    _thresholds_dict = thresholds_dict
    _models_dict = models_dict
    logger.info("[predictor] All models loaded successfully.")
def get_top50_taxa() -> list:
    """
    Return the top-50 taxon IDs (as ints) read from FEAT_META by load_all().

    Raises RuntimeError if load_all() has not completed.
    """
    if _top50_taxa is None:
        raise RuntimeError("Models not loaded — call load_all() first.")
    return _top50_taxa
def get_ia_weights() -> dict:
    """
    Return the IA weights unpickled from IA_PKL by load_all().

    Raises RuntimeError if load_all() has not completed.
    """
    if _ia_weights is None:
        raise RuntimeError("Models not loaded — call load_all() first.")
    return _ia_weights
def get_model_stats() -> dict:
    """
    Return the number of loaded classifiers per ontology.

    Used by the GET /model/info endpoint.
    Example return value: {"MFO": 1500, "BPO": 1500, "CCO": 1133}

    Raises RuntimeError if load_all() has not completed.
    """
    if _models_dict is None:
        raise RuntimeError("Models not loaded — call load_all() first.")
    return {ont: len(models) for ont, models in _models_dict.items()}
def predict(X_final: np.ndarray, protein_ids: list) -> list:
    """
    Run inference for all proteins across all 3 ontologies.

    Parameters
    ----------
    X_final : (N, 15411) float32 feature matrix; must be 2-D.
    protein_ids : list of N protein ID strings, aligned with the rows of X_final.

    Returns
    -------
    List of raw prediction dicts, one per (protein, GO term) pair whose
    confidence is >= that term's threshold (0.5 when none is stored):
        [{protein_id, go_term, ontology, confidence, threshold}, …]
    Ordered by ontology (ONTS order), then classifier order, then row order.
    confidence and threshold are rounded to 4 decimals.

    A classifier that raises (or returns a probability column of the wrong
    length) is skipped as a whole, counted, and logged.

    Raises
    ------
    RuntimeError
        If load_all() has not completed.
    ValueError
        If X_final is not 2-D or its row count differs from len(protein_ids).
    """
    if _models_dict is None:
        raise RuntimeError("Models not loaded — call load_all() first.")
    if X_final.ndim != 2:
        # Without this, 1-D input makes every classifier fail individually.
        raise ValueError(
            f"X_final must be 2-D (N, n_features); got shape {X_final.shape}."
        )
    N = X_final.shape[0]
    if N != len(protein_ids):
        raise ValueError(
            f"X_final has {N} rows but protein_ids has {len(protein_ids)} entries."
        )
    if N == 0:
        return []  # nothing to score; skip calling every classifier

    ids = list(protein_ids)  # guarantee positional indexing below
    all_preds = []
    failed_terms = 0
    for ont in ONTS:
        ont_models = _models_dict[ont]
        ont_thresholds = _thresholds_dict[ont]
        logger.info("[predictor] %s — scoring %d terms × %d proteins …", ont, len(ont_models), N)
        for go_term, clf in ont_models.items():
            threshold = float(ont_thresholds.get(go_term, 0.5))
            try:
                proba = clf.predict_proba(X_final)[:, 1]
                if len(proba) != N:
                    raise ValueError(f"expected {N} probabilities, got {len(proba)}")
                # Vectorised threshold test; flatnonzero keeps ascending row order.
                hits = np.flatnonzero(proba >= threshold)
            except Exception as exc:
                failed_terms += 1
                logger.warning("[predictor] Classifier failed %s/%s: %s", ont, go_term, exc)
                continue
            thr_out = round(threshold, 4)
            for i in hits:
                all_preds.append({
                    "protein_id": ids[i],
                    "go_term": go_term,
                    "ontology": ont,
                    "confidence": round(float(proba[i]), 4),
                    "threshold": thr_out,
                })
    if failed_terms:
        logger.warning("[predictor] Total failed classifiers: %d", failed_terms)
    logger.info("[predictor] Inference complete — %d raw predictions", len(all_preds))
    return all_preds