File size: 14,525 Bytes
c557a60
 
5c389ab
 
b417c18
66291d9
 
 
 
 
244e9e1
1162af1
924dc78
5c389ab
 
 
 
 
 
 
 
 
 
 
c557a60
 
5c389ab
 
 
 
 
 
 
 
10e78e3
5c389ab
10e78e3
 
 
66291d9
10e78e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b417c18
10e78e3
 
 
 
 
 
 
 
 
 
 
 
 
5c389ab
c557a60
5c389ab
 
 
 
 
 
 
 
 
c557a60
 
 
 
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c557a60
 
5c389ab
 
c557a60
5c389ab
c557a60
5c389ab
 
 
c557a60
 
 
5c389ab
 
 
 
 
 
 
 
 
 
c557a60
5c389ab
 
 
 
 
 
c557a60
 
5c389ab
10e78e3
 
 
344ea31
10e78e3
 
 
 
5c389ab
 
c557a60
 
5c389ab
 
 
c557a60
10e78e3
5c389ab
c557a60
5c389ab
 
 
 
c557a60
 
 
5c389ab
 
 
 
 
c557a60
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
c557a60
10e78e3
5c389ab
 
 
 
 
 
 
c557a60
5c389ab
 
c557a60
 
5c389ab
c557a60
5c389ab
 
 
c557a60
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c557a60
5c389ab
 
 
 
c557a60
 
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c557a60
5c389ab
c557a60
 
5c389ab
c557a60
 
5c389ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4b6edbd
10e78e3
5c389ab
10e78e3
 
 
 
 
c557a60
 
 
 
 
 
 
 
 
10e78e3
 
 
5c389ab
bc32c57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
# app.py — FunGO Hugging Face Space v2
import csv
import io
import logging
import os
import re as _re
import sys
import time
import uuid
from collections import OrderedDict

# Default runtime configuration, set before the project modules below are
# imported. Values already present in the environment are kept; only missing
# keys are filled in.
_ENV_DEFAULTS = {
    "FUNGO_PKL_DIR": "/tmp/models",
    "FUNGO_VOCAB_PKL": "/app/data/labels/vocabularies.pkl",
    "FUNGO_IA_PKL": "/app/data/go_data/ia_weights.pkl",
    "FUNGO_FEAT_META": "/app/data/features/feature_metadata.json",
    "FUNGO_MODEL_CACHE": "/tmp/esm2_cache",
    "FUNGO_EMB_CACHE": "/tmp/embedding_cache",
    "FUNGO_OFFLINE": "0",
    "FUNGO_DEBUG": "0",
    "FUNGO_PORT": "7860",
}
for _env_name, _env_default in _ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_default)
# Unconditionally disable Hub offline mode (overrides any inherited value).
os.environ["HF_HUB_OFFLINE"] = "0"
# Copy HF_TOKEN into HUGGING_FACE_HUB_TOKEN unless the latter is already set;
# it becomes "" when HF_TOKEN is unset.
os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", os.environ.get("HF_TOKEN", ""))

from flask import Flask, jsonify, request, Response
from flask_cors import CORS
import config
import predictor
import embedder
import filter as flt
import taxonomy

# Root logging: INFO and above, timestamped as HH:MM:SS.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s β€” %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("fungo.app")

# Flask application with flask_cors applied using its default options.
app = Flask(__name__)
CORS(app)
# Cap request bodies at 2 MiB; larger requests are answered by the 413 handler.
app.config.update(MAX_CONTENT_LENGTH=2 * 1024 * 1024)

# In-memory CSV results keyed by job_id, oldest first; _store_csv bounds it
# to _CSV_MAX entries.
_csv_store: OrderedDict = OrderedDict()
_CSV_MAX = 50
# Set to True by the startup code once predictor.load_all() succeeds.
_models_ready = False

# ── Download models from HF Model Repo ───────────────────────
def download_models_if_needed(models_dir=None, repo_id="Muteeba/FunGO-models"):
    """Make sure the per-ontology model pickles exist locally, fetching any missing.

    Args:
        models_dir: Directory holding the pickles. Defaults to ``$FUNGO_PKL_DIR``
            (the model directory this app configures), falling back to
            ``/tmp/models``; it is created if absent.
        repo_id: Hugging Face model repository to download from.

    Returns:
        True when all files are present (already, or after downloading);
        False when a download fails. Failures are logged, never raised.
    """
    if models_dir is None:
        models_dir = os.environ.get("FUNGO_PKL_DIR") or "/tmp/models"
    os.makedirs(models_dir, exist_ok=True)

    files_needed = ["models_BPO.pkl", "models_MFO.pkl", "models_CCO.pkl"]
    missing = [f for f in files_needed
               if not os.path.exists(os.path.join(models_dir, f))]
    if not missing:
        log.info("Model files already present in %s - skipping download", models_dir)
        return True

    log.info("Downloading %d model file(s) from %s ...", len(missing), repo_id)
    try:
        from huggingface_hub import hf_hub_download
        # An empty HF_TOKEN means "no token": pass None so huggingface_hub
        # uses its default credential lookup instead of sending "".
        token = os.environ.get("HF_TOKEN") or None
        for fname in missing:
            log.info("  Downloading %s ...", fname)
            hf_hub_download(
                token=token,
                repo_id=repo_id,
                filename=fname,
                repo_type="model",
                local_dir=models_dir,
            )
            log.info("  %s done!", fname)
    except Exception as e:
        # Best effort: startup continues and config.validate_paths() decides
        # whether predictions are enabled.
        log.error("Model download failed: %s", e)
        return False
    log.info("All model files downloaded!")
    return True

# ── CSV helpers ───────────────────────────────────────────────
def _store_csv(job_id, predictions):
    """Cache ``predictions`` under ``job_id`` for a later ``/predict/csv`` download.

    At most ``_CSV_MAX`` jobs are kept; the oldest are evicted first. Storing
    an existing ``job_id`` replaces that entry (making it the newest) without
    evicting any other job.
    """
    if job_id in _csv_store:
        del _csv_store[job_id]
    # Guard on non-empty so a non-positive _CSV_MAX cannot raise KeyError.
    while _csv_store and len(_csv_store) >= _CSV_MAX:
        _csv_store.popitem(last=False)
    _csv_store[job_id] = {"predictions": predictions, "ts": time.time()}

def _make_csv(predictions):
    """Render stored job predictions as CSV text.

    ``predictions`` maps protein ID -> dict whose ``"all"`` list holds the
    prediction dicts; each becomes one row. Missing fields are written as
    empty cells.
    """
    fields = ("go_term", "ontology", "ontology_label", "tier", "tier_label",
              "confidence", "ia_weight", "combined_score", "threshold")
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["protein_id", *fields])
    writer.writerows(
        [protein_id, *(pred.get(field, "") for field in fields)]
        for protein_id, entry in predictions.items()
        for pred in entry.get("all", [])
    )
    return buffer.getvalue()

_OX_RE = _re.compile(r"OX=(\d+)")


def _parse_taxon_id(header):
    """Return the integer following ``OX=`` in a FASTA header, or None.

    UniProt FASTA headers carry the organism identifier as ``OX=<digits>``.
    ``header`` may be None or empty.
    """
    m = _OX_RE.search(header or "")
    return int(m.group(1)) if m else None


def _fasta_record_id(header):
    """Choose the protein ID for a FASTA header (the text after ``>``).

    Headers with at least three ``|``-separated fields (UniProt style,
    ``db|ACCESSION|ENTRY_NAME ...``) and a non-blank second field yield that
    field; any other header yields its first whitespace-separated token.

    Raises:
        ValueError: if the header has no usable ID (e.g. a bare ``>`` line).
    """
    parts = header.split("|")
    if len(parts) >= 3 and parts[1].strip():
        return parts[1].strip()
    tokens = header.split()
    if not tokens:
        raise ValueError("FASTA header without a sequence ID ('>' line is empty).")
    return tokens[0]


def parse_fasta(fasta_text):
    """Parse FASTA text into protein records, in input order.

    Each record is a dict with keys ``id`` (see ``_fasta_record_id``), ``seq``
    (upper-cased residues with all whitespace removed), ``header`` (header
    text without ``>``) and ``taxon_id`` (from ``OX=``, else None). Blank
    lines, lines before the first header and records without residues are
    skipped.

    Raises:
        ValueError: if a record has no usable ID, two records share an ID
            (results are keyed by protein ID downstream, so they would be
            merged), or no record with residues is found.
    """
    # Group lines into (header, [sequence chunks]) first, then build records,
    # so the "finish the current record" step is written only once.
    entries = []
    for raw_line in fasta_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            entries.append((line[1:].strip(), []))
        elif entries:
            # Drop internal whitespace, e.g. from column-formatted sequences.
            entries[-1][1].append("".join(line.split()))

    proteins, seen = [], set()
    for header, chunks in entries:
        seq = "".join(chunks).upper()
        if not seq:
            continue
        pid = _fasta_record_id(header)
        if pid in seen:
            raise ValueError(f"Duplicate sequence ID '{pid}' in FASTA input.")
        seen.add(pid)
        proteins.append({"id": pid, "seq": seq, "header": header,
                         "taxon_id": _parse_taxon_id(header)})
    if not proteins:
        raise ValueError("No valid protein sequences found.")
    return proteins

def _run_prediction(fasta_text, taxon_id_override):
    """Parse FASTA input, embed the sequences and run GO-term prediction.

    Args:
        fasta_text: Raw FASTA text from the request.
        taxon_id_override: Taxon ID applied to every sequence, or None to use
            each sequence's own header ``OX=`` value.

    Returns:
        ``(proteins, raw_preds, ia_weights, elapsed_seconds)``. Every raw
        prediction gains an ``ia_weight`` key: its GO term's weight rounded to
        4 places, or 0.0 when the term is absent from ``ia_weights``.
        ``elapsed_seconds`` covers embedding through weighting, rounded to 2 places.

    Raises:
        ValueError: from ``parse_fasta``, or when more than
            ``config.MAX_SEQUENCES`` sequences are submitted.
    """
    proteins = parse_fasta(fasta_text)
    if len(proteins) > config.MAX_SEQUENCES:
        raise ValueError(f"Too many sequences. Max: {config.MAX_SEQUENCES}.")

    ids = [record["id"] for record in proteins]
    residues = [record["seq"] for record in proteins]
    taxa = [record["taxon_id"] if taxon_id_override is None else taxon_id_override
            for record in proteins]

    started = time.perf_counter()
    embeddings = embedder.extract(residues)
    features = embedder.build_features(embeddings, taxa, predictor.get_top50_taxa())
    raw_preds = predictor.predict(features, ids)
    ia_weights = predictor.get_ia_weights()
    for pred in raw_preds:
        weight = ia_weights.get(pred["go_term"], 0.0)
        pred["ia_weight"] = round(float(weight), 4)
    elapsed = round(time.perf_counter() - started, 2)
    return proteins, raw_preds, ia_weights, elapsed

# ── Routes ────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Describe the API: name, version, model readiness and endpoint list."""
    endpoints = ["/health", "/model/info", "/taxonomy/search",
                 "/taxonomy/verify", "/predict", "/predict/csv", "/predict/debug"]
    payload = {
        "name": "FunGO API",
        "version": "2.0.0",
        "status": "running",
        "models_ready": _models_ready,
        "endpoints": endpoints,
    }
    return jsonify(payload)

@app.route("/health", methods=["GET"])
def health():
    """Report service status: device, fp16 flag, API version and model readiness."""
    status = {
        "status": "ok",
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "version": "2.0.0",
        "models_ready": _models_ready,
    }
    return jsonify(status)

@app.route("/model/info", methods=["GET"])
def model_info():
    """Describe the loaded models, tier thresholds and display limit.

    Returns 503 while models are not loaded, or when
    ``predictor.get_model_stats()`` raises RuntimeError.
    """
    if not _models_ready:
        return jsonify({"error":"Models not loaded yet."}), 503
    try:
        stats = predictor.get_model_stats()
    except RuntimeError as err:
        return jsonify({"error": str(err)}), 503

    # Public tier names map onto config's GOLD / GOOD / SILVER thresholds.
    tier_thresholds = {
        "STRONG": {"min_ia": config.TIER_GOLD_IA, "min_conf": config.TIER_GOLD_CONF},
        "MODERATE": {"min_ia": config.TIER_GOOD_IA, "min_conf": config.TIER_GOOD_CONF},
        "INDICATIVE": {"min_ia": config.TIER_SILVER_IA, "min_conf": config.TIER_SILVER_CONF},
    }
    info = {
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "model_name": config.MODEL_NAME,
        "ontologies": stats,
        "top50_taxa_count": len(predictor.get_top50_taxa()),
        "thresholds": tier_thresholds,
        "display_limit": flt.TOP_N_DISPLAY,
    }
    return jsonify(info)

@app.route("/taxonomy/search", methods=["GET"])
def taxonomy_search():
    """Search species by name via ``taxonomy.search_species``.

    Query parameters:
        q: search text; at least 2 characters after stripping, else 400.
        max_results: optional integer, default 8 (also used when it is not an
            integer), clamped to the range 1..20.
    """
    q = request.args.get("q","").strip()
    if len(q) < 2:
        return jsonify({"error":"Query must be at least 2 characters."}), 400
    try:
        max_r = int(request.args.get("max_results", 8))
    except (TypeError, ValueError):
        max_r = 8
    # Clamp both ends: zero or negative values would otherwise reach the search.
    max_r = max(1, min(max_r, 20))
    return jsonify({"query":q,"results":taxonomy.search_species(q,max_results=max_r)})

@app.route("/taxonomy/verify", methods=["GET"])
def taxonomy_verify():
    """Resolve the ``taxon_id`` query parameter against the model's top-50 taxa.

    Returns the result of ``taxonomy.resolve_taxon`` as JSON, or 400 when the
    parameter is missing/blank or not an integer.
    """
    raw = request.args.get("taxon_id","").strip()
    if not raw:
        return jsonify({"error":"taxon_id required."}), 400
    try:
        taxon_id = int(raw)
    except ValueError:  # query args are strings, so int() fails only this way
        return jsonify({"error":f"Invalid taxon_id: '{raw}'"}), 400
    return jsonify(taxonomy.resolve_taxon(taxon_id, predictor.get_top50_taxa()))

@app.route("/predict", methods=["POST"])
def predict():
    """Predict GO terms for the FASTA sequences in a JSON request.

    Request body: ``{"fasta": "<FASTA text>", "taxon_id": <int>}``. ``taxon_id``
    is optional (null counts as absent); when given it replaces the ``OX=``
    taxon of every sequence.

    Response: ``job_id`` (pass to ``/predict/csv`` for the full CSV),
    ``metadata`` (counts and timing) and ``predictions`` keyed by protein ID.
    Status codes: 503 models not loaded or predictor RuntimeError, 415
    non-JSON request, 400 invalid input (including ValueError from parsing),
    500 anything else.
    """
    if not _models_ready:
        return jsonify({"error":"Models not loaded yet."}), 503
    if not request.is_json:
        return jsonify({"error":"Content-Type must be application/json."}), 415
    body = request.get_json(silent=True)
    if not isinstance(body, dict):  # malformed JSON (None) or a non-object payload
        body = {}
    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str):
        return jsonify({"error":"'fasta' must be a string."}), 400
    fasta_text = fasta_text.strip()
    if not fasta_text:
        return jsonify({"error":"'fasta' field is required."}), 400
    taxon_id_override = None
    if body.get("taxon_id") is not None:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError):
            return jsonify({"error":"Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
    except ValueError as e:
        return jsonify({"error":str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error":str(e)}), 503
    except Exception as e:
        log.exception("Prediction error")
        return jsonify({"error":str(e)}), 500

    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    predictions, csv_data = {}, {}
    total_display = total_all = 0
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        total_display += len(display)
        total_all += len(all_f)
        predictions[pid] = {"taxon_id":prot["taxon_id"],
            "summary":flt.summarise(display,all_f,pid),
            "display":display,"total_all":len(all_f)}
        csv_data[pid] = {"all":all_f}

    # Random, unguessable ID: a millisecond timestamp collides between
    # concurrent requests (one job's CSV overwrites another's) and lets anyone
    # enumerate other users' results through /predict/csv.
    job_id = uuid.uuid4().hex
    _store_csv(job_id, csv_data)
    return jsonify({"job_id":job_id,
        "metadata":{"n_proteins":len(protein_ids),"device":config.DEVICE,
            "total_raw_predictions":len(raw_preds),"total_filtered":total_all,
            "total_displayed":total_display,"display_limit":flt.TOP_N_DISPLAY,
            "elapsed_seconds":elapsed},
        "predictions":predictions})

@app.route("/predict/csv", methods=["GET"])
def download_csv():
    """Send the predictions stored for ``job_id`` as a CSV attachment.

    400 when ``job_id`` is missing/blank, 404 when no such job is cached.
    """
    job_id = request.args.get("job_id","").strip()
    if not job_id:
        return jsonify({"error":"job_id required."}), 400
    job = _csv_store.get(job_id)
    if not job:
        return jsonify({"error":f"Job '{job_id}' not found."}), 404
    csv_text = _make_csv(job["predictions"])
    disposition = f"attachment; filename=fungo_{job_id}.csv"
    return Response(csv_text, mimetype="text/csv",
                    headers={"Content-Disposition": disposition})

@app.route("/predict/debug", methods=["POST"])
def predict_debug():
    """Like ``/predict``, but also explains which raw predictions were rejected.

    Accepts the same JSON body as ``/predict``. For each protein it returns the
    displayed and all filtered predictions, every raw prediction not accepted
    by ``flt.filter_predictions`` with a ``reason`` (sorted by IA weight,
    highest first), and the tier thresholds used. No CSV job is stored.

    Status codes match ``/predict``: 503 models not loaded or predictor
    RuntimeError, 415 non-JSON request, 400 invalid input, 500 anything else.
    """
    if not _models_ready:
        return jsonify({"error":"Models not loaded."}), 503
    if not request.is_json:
        return jsonify({"error":"Content-Type must be application/json."}), 415
    body = request.get_json(silent=True)
    if not isinstance(body, dict):  # malformed JSON (None) or a non-object payload
        body = {}
    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str):
        return jsonify({"error":"'fasta' must be a string."}), 400
    fasta_text = fasta_text.strip()
    if not fasta_text:
        return jsonify({"error":"'fasta' required."}), 400
    taxon_id_override = None
    if body.get("taxon_id") is not None:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError):
            return jsonify({"error":"Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
    except ValueError as e:
        # Bad user input (unparsable FASTA, too many sequences) is a client error.
        return jsonify({"error":str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error":str(e)}), 503
    except Exception as e:
        log.exception("Debug error")
        return jsonify({"error":str(e)}), 500

    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)
    thr = {"STRONG":{"min_ia":config.TIER_GOLD_IA,"min_conf":config.TIER_GOLD_CONF},
           "MODERATE":{"min_ia":config.TIER_GOOD_IA,"min_conf":config.TIER_GOOD_CONF},
           "INDICATIVE":{"min_ia":config.TIER_SILVER_IA,"min_conf":config.TIER_SILVER_CONF}}
    predictions = {}
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        accepted = {p["go_term"] for p in all_f}
        fo = []
        for pred in raw_by_pid[pid]:
            go = pred["go_term"]
            if go in accepted:
                continue
            ia, conf = pred.get("ia_weight", float(ia_weights.get(go, 0.0))), pred["confidence"]
            # Best-guess reason, checked against the lowest (INDICATIVE) tier.
            if go in config.BLACKLIST_TERMS: reason = "blacklisted"
            elif ia <= config.TIER_SILVER_IA: reason = f"ia_too_low (ia={ia:.4f})"
            elif conf < config.TIER_SILVER_CONF: reason = f"conf_too_low (conf={conf:.4f})"
            else: reason = "below_all_tiers"
            fo.append({"go_term":go,"ontology":pred["ontology"],
                       "confidence":conf,"ia_weight":ia,"reason":reason})
        fo.sort(key=lambda x: -x["ia_weight"])
        predictions[pid] = {"taxon_id":prot["taxon_id"],
            "summary":flt.summarise(display,all_f,pid),
            "display":display,"all_filtered":all_f,
            "filtered_out":fo,"thresholds_used":thr}
    return jsonify({"metadata":{"n_proteins":len(protein_ids),"device":config.DEVICE,
        "total_raw":len(raw_preds),"elapsed_seconds":elapsed},"predictions":predictions})

@app.errorhandler(404)
def not_found(e):
    """Return a JSON body for unknown routes."""
    return jsonify({"error":"Not found."}), 404


@app.errorhandler(413)
def too_large(e):
    """Return a JSON body when a request exceeds MAX_CONTENT_LENGTH."""
    return jsonify({"error":"Request too large."}), 413


@app.errorhandler(500)
def internal(e):
    """Log the failure with its traceback and return a generic JSON 500."""
    log.exception("Unhandled error")
    return jsonify({"error":"Internal server error."}), 500

# ── Startup ───────────────────────────────────────────────────
# Runs at import time, so gunicorn workers and `python app.py` share one code
# path. It must come before the __main__ guard because app.run() blocks.
log.info("FunGO v2.0 starting ...")
config.ensure_dirs()
download_models_if_needed()
paths_ok = config.validate_paths()
if paths_ok:
    try:
        predictor.load_all()
        _models_ready = True
        log.info("Models loaded successfully!")
    except Exception as e:
        # Keep serving: prediction endpoints answer 503 while _models_ready is False.
        log.exception("Model loading failed: %s", e)
else:
    log.warning("Some paths missing β€” predictions disabled")

if __name__ == "__main__":
    # Honour FUNGO_PORT (defaults to 7860) instead of a hard-coded port.
    _port = int(os.environ.get("FUNGO_PORT", "7860"))
    log.info("Serving on port %d ...", _port)
    app.run(host="0.0.0.0", port=_port, debug=False)