File size: 7,465 Bytes
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# predictor.py
"""
FunGO — Prediction Engine
===========================
Loads XGBoost models once at startup.
Runs inference across all 3 ontologies (MFO, BPO, CCO).

Changes from original:
  1. Added get_model_stats() — returns classifier counts per ontology
     (used by /model/info endpoint).
  2. Fixed open() to use context managers (file handles now closed).
  3. tempfile.mktemp() replaced with NamedTemporaryFile (WSL fix).
  4. Failed classifiers are counted and logged instead of silent pass.
  5. Input shape validation in predict().
"""

import json
import logging
import pickle
import shutil
import subprocess
import tempfile
import numpy as np
from pathlib import Path

from config import PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META

logger = logging.getLogger(__name__)

# Ontologies handled by this module, in the order they are loaded and scored.
ONTS = ["MFO", "BPO", "CCO"]

# ── Globals ───────────────────────────────────────────────────
# In-memory state filled in by load_all(); every name is None until then.
_models_dict: "dict | None" = None       # {ontology: {GO term: classifier}}
_thresholds_dict: "dict | None" = None   # {ontology: {GO term: decision threshold}}
_ia_weights = None                       # unpickled from IA_PKL; logged as a term count
_vocabularies = None                     # unpickled from VOCAB_PKL
_top50_taxa: "list | None" = None        # taxon IDs (ints) read from FEAT_META


# ── Helpers ───────────────────────────────────────────────────

def _wsl_copy(src: Path) -> Path:
    """
    Copy *src* to a fresh temporary ``.pkl`` file and return the copy's path.

    Workaround for permission problems on WSL-mounted drives: the copy is
    created in the local temp directory. The caller owns the returned file
    and is responsible for deleting it.

    If the copy fails, the temporary file is removed before the exception
    propagates, so a failed call never leaves an orphan file behind.
    """
    # delete=False: we only want a unique, already-created path. The handle is
    # closed on leaving the ``with`` so shutil can reopen the path for writing.
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
        tmp_path = Path(tmp.name)
    try:
        shutil.copy2(str(src), str(tmp_path))
    except BaseException:
        # The caller never receives tmp_path on failure, so clean up here.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
    return tmp_path


def _safe_load(path: Path) -> object:
    """
    Unpickle the file at *path*, with a WSL permission workaround.

    Steps:
      1. Best-effort ``chmod 644`` on *path*. Note this changes the file's
         permissions on disk as a side effect.
      2. Open and unpickle *path* directly.
      3. If opening/reading raises PermissionError, copy the file to a
         temporary path via _wsl_copy(), unpickle the copy, and always delete
         the copy afterwards.

    Errors other than PermissionError (e.g. FileNotFoundError, unpickling
    errors) propagate from step 2 unchanged.

    Security: pickle.load can execute arbitrary code; only use this on
    trusted, locally produced files.
    """
    try:
        subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
    except OSError:
        # e.g. no ``chmod`` executable on this platform; the fix-up is optional.
        pass

    try:
        with open(path, "rb") as fh:
            return pickle.load(fh)
    except PermissionError:
        pass  # fall through to the temp-copy workaround

    tmp_path = None
    try:
        tmp_path = _wsl_copy(path)
        with open(tmp_path, "rb") as fh:
            return pickle.load(fh)
    finally:
        if tmp_path is not None and tmp_path.exists():
            tmp_path.unlink()


def _safe_read_json(path: Path) -> dict:
    """
    Parse the JSON file at *path*, with a WSL permission workaround.

    The file is read as raw bytes and passed straight to json.loads, which
    detects UTF-8/16/32 itself and accepts a leading UTF-8 BOM (as written
    by some Windows tools). The result therefore does not depend on the
    process locale. If the direct read raises PermissionError, the bytes are
    obtained from ``cat`` instead.

    Side effect: a best-effort ``chmod 644`` is applied to *path* first.

    Raises
    ------
    OSError                       : the file cannot be opened (other than PermissionError).
    subprocess.CalledProcessError : the ``cat`` fallback fails.
    ValueError                    : the content is not valid JSON / valid Unicode
                                    (json.JSONDecodeError, UnicodeDecodeError).
    """
    try:
        subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
    except OSError:
        # e.g. no ``chmod`` executable on this platform; the fix-up is optional.
        pass

    try:
        with open(path, "rb") as fh:
            raw = fh.read()
    except PermissionError:
        # Last resort: let an external process read the file. Keep stdout as
        # bytes so decoding is again left to json.loads, not the locale.
        raw = subprocess.run(["cat", str(path)], capture_output=True, check=True).stdout
    return json.loads(raw)


# ── Public API ────────────────────────────────────────────────

def load_all():
    """
    Load all models and supporting data into memory.
    Call once at Flask startup (~30–120 s depending on hardware).

    Reads VOCAB_PKL, IA_PKL, FEAT_META (for the top-50 taxa) and, for each
    ontology in ONTS, ``PKL_DIR / f"models_{ont}.pkl"``.

    The module globals are assigned only after every file has loaded
    successfully. A failure part-way through therefore leaves the previous
    state intact (all None on the first call). Without this, predict() and
    get_model_stats() would see a half-filled _models_dict and mistake it
    for "loaded".

    Raises whatever the loaders raise: OSError, unpickling/JSON errors,
    KeyError for a malformed bundle, ValueError for an empty model file.
    """
    global _models_dict, _thresholds_dict, _ia_weights, _vocabularies, _top50_taxa

    logger.info("[predictor] Loading vocabularies …")
    vocabularies = _safe_load(VOCAB_PKL)

    logger.info("[predictor] Loading IA weights …")
    ia_weights = _safe_load(IA_PKL)
    logger.info("[predictor] IA weights: %d terms", len(ia_weights))

    meta       = _safe_read_json(FEAT_META)
    top50_taxa = [int(t) for t in meta["taxonomy_info"]["top50_taxa"]]
    logger.info("[predictor] Top-50 taxa loaded (%d)", len(top50_taxa))

    models_dict     = {}
    thresholds_dict = {}

    for ont in ONTS:
        pkl_path = PKL_DIR / f"models_{ont}.pkl"
        size_mb  = pkl_path.stat().st_size / 1e6
        logger.info("[predictor] Loading %s (%.0f MB) …", pkl_path.name, size_mb)

        models_d, thresholds_d = _unpack_model_bundle(_safe_load(pkl_path), pkl_path.name)

        models_dict[ont]     = models_d
        thresholds_dict[ont] = thresholds_d
        logger.info("[predictor] %s: %d classifiers ready", ont, len(models_d))

    # Publish everything at once, now that nothing else can fail.
    # _models_dict goes last because the public functions gate on it.
    _vocabularies    = vocabularies
    _ia_weights      = ia_weights
    _top50_taxa      = top50_taxa
    _thresholds_dict = thresholds_dict
    _models_dict     = models_dict

    logger.info("[predictor] All models loaded successfully.")


def _unpack_model_bundle(raw, source: str) -> tuple:
    """
    Split one ontology's unpickled model file into (models_d, thresholds_d).

    Two layouts are accepted, told apart by the first key:
      * flat   ``{"GO:…": classifier, …}``: every threshold defaults to 0.5;
      * bundle ``{"models": [...], "selected_terms": [...], "thresholds": ...}``,
        where ``models`` and ``selected_terms`` are paired positionally. The
        optional ``thresholds`` is either a sequence paired the same way or a
        ``{go_term: threshold}`` mapping (missing terms default to 0.5).

    Both returned dicts are keyed by GO term. *source* names the file in
    log/error messages.

    Raises ValueError if *raw* is empty, and KeyError if a bundle lacks
    "models"/"selected_terms".
    """
    if not raw:
        raise ValueError(f"[predictor] {source}: model file contains no entries.")
    first_key = next(iter(raw))

    if isinstance(first_key, str) and first_key.startswith("GO:"):
        return raw, {t: 0.5 for t in raw}

    clf_list  = raw["models"]
    term_list = raw["selected_terms"]
    if len(clf_list) != len(term_list):
        # zip() below truncates to the shorter list; make that visible.
        logger.warning(
            "[predictor] %s: %d models but %d selected_terms - extra entries ignored",
            source, len(clf_list), len(term_list),
        )

    thr_raw = raw.get("thresholds")
    if thr_raw is None:
        thr_list = [0.5] * len(clf_list)
    elif isinstance(thr_raw, dict):
        thr_list = [thr_raw.get(t, 0.5) for t in term_list]
    else:
        thr_list = list(thr_raw)

    return dict(zip(term_list, clf_list)), dict(zip(term_list, thr_list))


def get_top50_taxa() -> list:
    """
    Return the top-50 taxon IDs (as ints) read by load_all().

    Raises RuntimeError when called before load_all() has run.
    """
    taxa = _top50_taxa
    if taxa is not None:
        return taxa
    raise RuntimeError("Models not loaded β€” call load_all() first.")


def get_ia_weights() -> dict:
    """
    Return the IA weights object unpickled by load_all().

    Raises RuntimeError when called before load_all() has run.
    """
    weights = _ia_weights
    if weights is not None:
        return weights
    raise RuntimeError("Models not loaded β€” call load_all() first.")


def get_model_stats() -> dict:
    """
    Map each loaded ontology to the number of classifiers it holds.

    Backs the GET /model/info endpoint, e.g.
    {"MFO": 1500, "BPO": 1500, "CCO": 1133}.

    Raises RuntimeError when called before load_all() has run.
    """
    loaded = _models_dict
    if loaded is None:
        raise RuntimeError("Models not loaded β€” call load_all() first.")
    counts = {}
    for ontology in loaded:
        counts[ontology] = len(loaded[ontology])
    return counts


def predict(X_final: np.ndarray, protein_ids: list) -> list:
    """
    Run inference for all proteins across all 3 ontologies.

    Every loaded classifier scores every protein. A (protein, GO term) pair
    is kept when its positive-class probability (column 1 of predict_proba)
    is >= that term's threshold, which is 0.5 when none is stored. A
    classifier that raises, or returns the wrong number of scores, is
    skipped, counted and logged instead of aborting the batch.

    Parameters
    ----------
    X_final     : (N, 15411) float32 feature matrix; must be 2-D
    protein_ids : list of N protein ID strings, aligned with X_final rows

    Returns
    -------
    List of raw prediction dicts, ordered by ontology (ONTS order), then GO
    term, then protein (input order):
        [{protein_id, go_term, ontology, confidence, threshold}, …]
    ``confidence`` and ``threshold`` are rounded to 4 decimals.

    Raises
    ------
    RuntimeError : load_all() has not been called.
    ValueError   : X_final is not 2-D, or its row count != len(protein_ids).
    """
    if _models_dict is None:
        raise RuntimeError("Models not loaded β€” call load_all() first.")

    if X_final.ndim != 2:
        raise ValueError(
            f"X_final must be 2-D (N, n_features); got shape {X_final.shape}."
        )
    N = X_final.shape[0]
    if N != len(protein_ids):
        raise ValueError(
            f"X_final has {N} rows but protein_ids has {len(protein_ids)} entries."
        )
    pids = list(protein_ids)  # materialise once for positional indexing below

    all_preds    = []
    failed_terms = 0

    for ont in ONTS:
        ont_models     = _models_dict[ont]
        ont_thresholds = _thresholds_dict[ont]
        n_terms        = len(ont_models)
        logger.info("[predictor] %s β€” scoring %d terms Γ— %d proteins …", ont, n_terms, N)

        for go_term, clf in ont_models.items():
            threshold = float(ont_thresholds.get(go_term, 0.5))

            # Guard only the classifier call: one bad model must not sink the
            # batch, and no partial output is emitted for a failing term.
            try:
                # float64 so ">= threshold" matches a per-element float()
                # comparison exactly; a float32 array would otherwise be
                # compared in float32.
                proba = np.asarray(clf.predict_proba(X_final)[:, 1], dtype=np.float64)
                if proba.shape != (N,):
                    raise ValueError(f"expected {N} scores, got shape {proba.shape}")
            except Exception as exc:
                failed_terms += 1
                logger.warning("[predictor] Classifier failed %s/%s: %s", ont, go_term, exc)
                continue

            thr_out = round(threshold, 4)
            # Vectorised threshold test; flatnonzero keeps protein input order.
            all_preds.extend(
                {
                    "protein_id": pids[i],
                    "go_term":    go_term,
                    "ontology":   ont,
                    "confidence": round(float(proba[i]), 4),
                    "threshold":  thr_out,
                }
                for i in np.flatnonzero(proba >= threshold)
            )

    if failed_terms:
        logger.warning("[predictor] Total failed classifiers: %d", failed_terms)

    logger.info("[predictor] Inference complete β€” %d raw predictions", len(all_preds))
    return all_preds