# predictor.py
"""
FunGO β Prediction Engine
===========================
Loads XGBoost models once at startup.
Runs inference across all 3 ontologies (MFO, BPO, CCO).
Changes from original:
1. Added get_model_stats() β returns classifier counts per ontology
(used by /model/info endpoint).
2. Fixed open() to use context managers (file handles now closed).
3. tempfile.mktemp() replaced with NamedTemporaryFile (WSL fix).
4. Failed classifiers are counted and logged instead of silent pass.
5. Input shape validation in predict().
"""
import json
import logging
import pickle
import shutil
import subprocess
import tempfile
import numpy as np
from pathlib import Path
from config import PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Ontology keys, in the order load_all() loads them and predict() scores them.
ONTS = ["MFO", "BPO", "CCO"]
# ── Globals ───────────────────────────────────────────────────
# All of the values below are assigned by load_all() and are None before it
# runs; the accessors that read them raise RuntimeError while they are None.
_models_dict = None  # {ontology: {GO term: classifier}}
_thresholds_dict = None  # {ontology: {GO term: decision threshold}}
_ia_weights = None  # object unpickled from IA_PKL
_vocabularies = None  # object unpickled from VOCAB_PKL
_top50_taxa = None  # list of int taxon IDs read from FEAT_META
# ── Helpers ───────────────────────────────────────────────────
def _wsl_copy(src: Path) -> Path:
    """Copy *src* to a fresh temporary ``.pkl`` file and return the copy's path.

    Used as a workaround when a file on a WSL mounted drive cannot be opened
    in place: the data is copied to the system temp directory and read there.

    Parameters
    ----------
    src : path of the file to copy.

    Returns
    -------
    Path of the temporary copy. The caller owns this file and is responsible
    for deleting it.

    Raises
    ------
    OSError (or whatever ``shutil.copy2`` raises) if the copy fails. In that
    case the temporary file is removed before the error propagates, so a
    failed copy does not leave a stray file behind.
    """
    # delete=False: the handle is only used to reserve a unique name. It is
    # closed when the ``with`` block exits, before copy2 writes to the path.
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
        tmp_path = Path(tmp.name)
    try:
        shutil.copy2(str(src), str(tmp_path))
    except BaseException:
        # The caller never receives tmp_path on failure, so clean up here.
        try:
            tmp_path.unlink()
        except OSError:
            # Best-effort cleanup; never mask the original copy error.
            pass
        raise
    return tmp_path
def _safe_load(path: Path) -> object:
    """Unpickle *path*, retrying from a temporary copy on ``PermissionError``.

    Steps:
    1. Best-effort ``chmod 644`` on the file (exit status ignored).
    2. Open and unpickle the file in place.
    3. If step 2 raises ``PermissionError``, copy the file to a temporary
       location with ``_wsl_copy`` and unpickle the copy; the copy is always
       removed afterwards.

    Parameters
    ----------
    path : pickle file to load.

    Returns
    -------
    The unpickled object.

    Raises
    ------
    Any error other than ``PermissionError`` from the in-place read (for
    example ``FileNotFoundError`` or ``pickle.UnpicklingError``), and any
    error from the fallback copy/read.
    """
    # Best-effort permission fix. check=False ignores a non-zero exit status;
    # only a failure to launch chmod at all (e.g. binary missing) lands here.
    try:
        subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
    except (OSError, ValueError) as exc:
        logger.debug("[predictor] chmod skipped for %s: %s", path, exc)

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load pickles that come from a trusted source.
    try:
        with open(path, "rb") as fh:
            return pickle.load(fh)
    except PermissionError as exc:
        logger.debug("[predictor] In-place read of %s denied (%s); using temp copy", path, exc)

    # If the copy itself fails there is nothing to clean up here, so the
    # try/finally starts only once a temp path exists.
    tmp_path = _wsl_copy(path)
    try:
        with open(tmp_path, "rb") as fh:
            return pickle.load(fh)
    finally:
        if tmp_path.exists():
            tmp_path.unlink()
def _safe_read_json(path: Path) -> dict:
    """Parse the JSON file at *path*, with a ``cat`` fallback on ``PermissionError``.

    The file is read as raw bytes and handed to ``json.loads``, which detects
    UTF-8 (with or without BOM), UTF-16 and UTF-32 itself. This keeps parsing
    independent of the process locale.

    Steps:
    1. Best-effort ``chmod 644`` on the file (exit status ignored).
    2. Read the file in binary mode.
    3. If step 2 raises ``PermissionError``, read the bytes through
       ``cat`` instead (``check=True``, so a failing ``cat`` raises).

    Parameters
    ----------
    path : JSON file to read.

    Returns
    -------
    The decoded JSON document.

    Raises
    ------
    FileNotFoundError / other OSError from the direct read,
    subprocess.CalledProcessError if the ``cat`` fallback fails,
    ValueError (including json.JSONDecodeError) if the content is not valid JSON.
    """
    # Best-effort permission fix; only a failure to launch chmod lands here.
    try:
        subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
    except (OSError, ValueError) as exc:
        logger.debug("[predictor] chmod skipped for %s: %s", path, exc)

    try:
        with open(path, "rb") as fh:
            payload = fh.read()
    except PermissionError as exc:
        logger.debug("[predictor] Direct read of %s denied (%s); falling back to cat", path, exc)
        # No text=True: keep stdout as bytes so decoding is not locale-bound.
        completed = subprocess.run(["cat", str(path)], capture_output=True, check=True)
        payload = completed.stdout
    return json.loads(payload)
# ── Public API ────────────────────────────────────────────────
def _unpack_model_bundle(raw, source):
    """Split one unpickled model file into ``(models, thresholds)`` dicts keyed by GO term.

    Two layouts are recognised:
    * a flat mapping whose keys start with ``"GO:"`` (term -> classifier);
      every term then gets the default threshold 0.5;
    * a mapping with ``"models"`` and ``"selected_terms"`` sequences and an
      optional ``"thresholds"`` entry (sequence aligned with the terms, or a
      dict keyed by term; missing values default to 0.5).

    *source* is only used in messages.
    """
    first_key = next(iter(raw), None)
    if first_key is None:
        # next(iter(raw)) without a default would leak StopIteration here.
        raise ValueError(f"{source}: model file contains no entries.")

    if isinstance(first_key, str) and first_key.startswith("GO:"):
        return raw, {term: 0.5 for term in raw}

    clf_list = raw["models"]
    term_list = raw["selected_terms"]
    thr_raw = raw.get("thresholds", [0.5] * len(clf_list))
    if isinstance(thr_raw, dict):
        thr_list = [thr_raw.get(term, 0.5) for term in term_list]
    else:
        thr_list = list(thr_raw)

    if not (len(term_list) == len(clf_list) == len(thr_list)):
        # zip() below stops at the shortest input; make the truncation visible.
        logger.warning(
            "[predictor] %s: length mismatch (terms=%d, models=%d, thresholds=%d)",
            source, len(term_list), len(clf_list), len(thr_list),
        )
    return dict(zip(term_list, clf_list)), dict(zip(term_list, thr_list))


def load_all():
    """Load all models and supporting data into the module-level globals.

    Reads the vocabulary pickle, the IA-weight pickle, the feature-metadata
    JSON (for ``taxonomy_info.top50_taxa``) and one ``models_<ONT>.pkl`` per
    ontology in ``ONTS``.

    Intended to be called once at application startup; the original notes
    estimate roughly 30-120 s depending on hardware.

    Everything is loaded into local variables first and published to the
    globals only after every file has loaded, so a failure part-way through
    leaves the previous state (``None`` on first call) untouched instead of a
    half-populated ``_models_dict``.

    Raises
    ------
    Whatever the underlying file reads raise (e.g. ``FileNotFoundError``),
    ``KeyError`` for missing metadata/bundle keys, and ``ValueError`` for an
    empty model file.
    """
    global _models_dict, _thresholds_dict, _ia_weights, _vocabularies, _top50_taxa

    logger.info("[predictor] Loading vocabularies β¦")
    vocabularies = _safe_load(VOCAB_PKL)

    logger.info("[predictor] Loading IA weights β¦")
    ia_weights = _safe_load(IA_PKL)
    logger.info("[predictor] IA weights: %d terms", len(ia_weights))

    meta = _safe_read_json(FEAT_META)
    top50_taxa = [int(t) for t in meta["taxonomy_info"]["top50_taxa"]]
    logger.info("[predictor] Top-50 taxa loaded (%d)", len(top50_taxa))

    models_by_ont = {}
    thresholds_by_ont = {}
    for ont in ONTS:
        pkl_path = PKL_DIR / f"models_{ont}.pkl"
        size_mb = pkl_path.stat().st_size / 1e6
        logger.info("[predictor] Loading %s (%.0f MB) β¦", pkl_path.name, size_mb)
        models_d, thresholds_d = _unpack_model_bundle(_safe_load(pkl_path), pkl_path.name)
        models_by_ont[ont] = models_d
        thresholds_by_ont[ont] = thresholds_d
        logger.info("[predictor] %s: %d classifiers ready", ont, len(models_d))

    # Publish only now that every load succeeded. _models_dict goes last
    # because predict() and get_model_stats() use it as the "loaded" flag.
    _vocabularies = vocabularies
    _ia_weights = ia_weights
    _top50_taxa = top50_taxa
    _thresholds_dict = thresholds_by_ont
    _models_dict = models_by_ont
    logger.info("[predictor] All models loaded successfully.")
def get_top50_taxa() -> list:
    """Return the taxon-ID list stored by ``load_all()``.

    Raises
    ------
    RuntimeError if ``load_all()`` has not populated the list yet.
    """
    taxa = _top50_taxa
    if taxa is not None:
        return taxa
    raise RuntimeError("Models not loaded β call load_all() first.")
def get_ia_weights() -> dict:
    """Return the IA-weights object stored by ``load_all()``.

    Raises
    ------
    RuntimeError if ``load_all()`` has not populated the weights yet.
    """
    weights = _ia_weights
    if weights is not None:
        return weights
    raise RuntimeError("Models not loaded β call load_all() first.")
def get_model_stats() -> dict:
    """Return the number of loaded classifiers for each ontology.

    The result maps ontology key to classifier count, e.g. shaped like
    ``{"MFO": 1500, "BPO": 1500, "CCO": 1133}``. According to the original
    docstring this backs the ``GET /model/info`` endpoint.

    Raises
    ------
    RuntimeError if ``load_all()`` has not been called yet.
    """
    loaded = _models_dict
    if loaded is None:
        raise RuntimeError("Models not loaded β call load_all() first.")
    counts = {}
    for ontology, classifiers in loaded.items():
        counts[ontology] = len(classifiers)
    return counts
def predict(X_final: np.ndarray, protein_ids: list) -> list:
    """Run inference for all proteins across every ontology in ``ONTS``.

    For each GO-term classifier, ``predict_proba(X_final)[:, 1]`` is compared
    against that term's threshold (default 0.5) and one record is emitted per
    protein whose score reaches the threshold.

    Parameters
    ----------
    X_final : feature matrix with one row per protein. The original docstring
        describes it as (N, 15411) float32; only the row count is checked here.
    protein_ids : N protein ID strings, aligned with the rows of ``X_final``.

    Returns
    -------
    List of dicts with keys ``protein_id``, ``go_term``, ``ontology``,
    ``confidence`` and ``threshold`` (both numbers rounded to 4 decimals),
    ordered by ontology, then term, then row index. Empty when N is 0.

    Raises
    ------
    RuntimeError if ``load_all()`` has not been called.
    ValueError if the row count of ``X_final`` differs from ``len(protein_ids)``.

    A classifier that raises, or returns a score vector whose length is not
    N, is counted and logged as failed and contributes no records at all.
    """
    if _models_dict is None:
        raise RuntimeError("Models not loaded β call load_all() first.")
    N = X_final.shape[0]
    if N != len(protein_ids):
        raise ValueError(
            f"X_final has {N} rows but protein_ids has {len(protein_ids)} entries."
        )
    if N == 0:
        # Nothing to score; avoid one classifier call (and warning) per term.
        logger.info("[predictor] No proteins supplied; skipping inference.")
        return []

    # Positional access below must not depend on the container type.
    ids = list(protein_ids)
    all_preds = []
    failed_terms = 0
    for ont in ONTS:
        ont_models = _models_dict[ont]
        ont_thresholds = _thresholds_dict[ont]
        logger.info("[predictor] %s β scoring %d terms Γ %d proteins β¦", ont, len(ont_models), N)
        for go_term, clf in ont_models.items():
            threshold = float(ont_thresholds.get(go_term, 0.5))
            # Keep the try narrow: only scoring can fail, so a failed term
            # never leaves partially appended records behind.
            try:
                proba = clf.predict_proba(X_final)[:, 1]
                # Widen to float64 so the comparison matches
                # float(proba[i]) >= threshold exactly, even for float32 output.
                scores = np.asarray(proba, dtype=np.float64)
                if scores.shape != (N,):
                    raise ValueError(
                        f"expected {N} scores, got array of shape {scores.shape}"
                    )
                hit_rows = np.flatnonzero(scores >= threshold).tolist()
            except Exception as exc:
                failed_terms += 1
                logger.warning("[predictor] Classifier failed %s/%s: %s", ont, go_term, exc)
                continue

            if not hit_rows:
                continue
            shown_threshold = round(threshold, 4)
            all_preds.extend(
                {
                    "protein_id": ids[row],
                    "go_term": go_term,
                    "ontology": ont,
                    "confidence": round(float(scores[row]), 4),
                    "threshold": shown_threshold,
                }
                for row in hit_rows
            )

    if failed_terms:
        logger.warning("[predictor] Total failed classifiers: %d", failed_terms)
    logger.info("[predictor] Inference complete β %d raw predictions", len(all_preds))
    return all_preds
|