Muteeba committed on
Commit
5c389ab
Β·
1 Parent(s): 4e8a676

FunGO v2.0 backend

Browse files
Files changed (8) hide show
  1. app.py +278 -0
  2. config.py +104 -0
  3. embedder.py +187 -0
  4. filter.py +152 -0
  5. hf_README.md +41 -0
  6. predictor.py +216 -0
  7. requirements.txt +9 -0
  8. taxonomy.py +200 -0
app.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — FunGO HuggingFace Space
2
+ """
3
+ FunGO v2.0 — HuggingFace Spaces Deployment
4
+ =============================================
5
+ Flask API running on port 7860.
6
+ Model files loaded from /data/ (HF persistent storage).
7
+
8
+ To upload model files:
9
+ pip install huggingface_hub
10
+ huggingface-cli login
11
+ huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/models /data/models --repo-type=space
12
+ huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/labels /data/labels --repo-type=space
13
+ huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/go_data /data/go_data --repo-type=space
14
+ huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/features /data/features --repo-type=space
15
+ huggingface-cli upload Muteeba/FunGO /mnt/e/repeat/embeddings/model_cache /data/esm2_cache --repo-type=space
16
+ """
17
+
18
+ import csv
19
+ import io
20
+ import logging
21
+ import os
22
+ import re as _re
23
+ import sys
24
+ import time
25
+ from collections import OrderedDict
26
+
27
+ # ── HuggingFace paths ─────────────────────────────────────────
28
+ os.environ.setdefault("FUNGO_PKL_DIR", "/data/models")
29
+ os.environ.setdefault("FUNGO_VOCAB_PKL", "/data/labels/vocabularies.pkl")
30
+ os.environ.setdefault("FUNGO_IA_PKL", "/data/go_data/ia_weights.pkl")
31
+ os.environ.setdefault("FUNGO_FEAT_META", "/data/features/feature_metadata.json")
32
+ os.environ.setdefault("FUNGO_MODEL_CACHE","/data/esm2_cache")
33
+ os.environ.setdefault("FUNGO_EMB_CACHE", "/data/embedding_cache")
34
+ os.environ.setdefault("FUNGO_OFFLINE", "1")
35
+ os.environ.setdefault("FUNGO_DEBUG", "0")
36
+ os.environ.setdefault("FUNGO_PORT", "7860")
37
+
38
+ from flask import Flask, jsonify, request, Response
39
+ from flask_cors import CORS
40
+
41
+ import config
42
+ import predictor
43
+ import embedder
44
+ import filter as flt
45
+ import taxonomy
46
+
47
+ logging.basicConfig(
48
+ level=logging.INFO,
49
+ format="%(asctime)s [%(levelname)s] %(name)s β€” %(message)s",
50
+ datefmt="%H:%M:%S",
51
+ )
52
+ log = logging.getLogger("fungo.app")
53
+
54
+ app = Flask(__name__)
55
+ CORS(app)
56
+ app.config["MAX_CONTENT_LENGTH"] = 2 * 1024 * 1024
57
+
58
_csv_store: OrderedDict = OrderedDict()
_CSV_MAX = 50


def _store_csv(job_id, predictions):
    """Remember a job's full prediction set so it can be downloaded as CSV.

    The store is a bounded, insertion-ordered cache: once it holds
    ``_CSV_MAX`` jobs, the oldest entries are dropped to make room.

    Args:
        job_id: Key under which the predictions are stored.
        predictions: Per-protein prediction mapping to keep for this job.
    """
    # Re-storing an existing job must not evict an unrelated job, and the
    # refreshed entry should count as the newest one, so drop the stale copy
    # before checking capacity.
    _csv_store.pop(job_id, None)
    while len(_csv_store) >= _CSV_MAX:
        _csv_store.popitem(last=False)
    _csv_store[job_id] = {"predictions": predictions, "ts": time.time()}
65
+
66
+ def _make_csv(predictions):
67
+ out = io.StringIO()
68
+ w = csv.writer(out)
69
+ w.writerow(["protein_id","go_term","ontology","ontology_label",
70
+ "tier","tier_label","confidence","ia_weight","combined_score","threshold"])
71
+ for pid, data in predictions.items():
72
+ for p in data.get("all", []):
73
+ w.writerow([pid, p.get("go_term",""), p.get("ontology",""),
74
+ p.get("ontology_label",""), p.get("tier",""), p.get("tier_label",""),
75
+ p.get("confidence",""), p.get("ia_weight",""),
76
+ p.get("combined_score",""), p.get("threshold","")])
77
+ return out.getvalue()
78
+
79
_OX_RE = _re.compile(r"OX=(\d+)")


def _parse_taxon_id(header):
    """Return the integer from an ``OX=<digits>`` tag in *header*, or None."""
    m = _OX_RE.search(header or "")
    return int(m.group(1)) if m else None


def _fasta_record(record_id, header, seq_lines):
    """Build one protein record dict, or None when it has no residues."""
    seq = "".join(seq_lines).upper()
    if not seq:
        return None
    return {"id": record_id, "seq": seq,
            "header": header, "taxon_id": _parse_taxon_id(header)}


def parse_fasta(fasta_text):
    """Parse FASTA text into a list of protein records.

    Each record is a dict with keys ``id``, ``seq`` (upper-cased),
    ``header`` (header line without the leading ``>``) and ``taxon_id``
    (from an ``OX=`` tag, else None). For pipe-delimited headers with at
    least three fields (``db|ACCESSION|NAME ...``) the ID is the second
    field; otherwise it is the first whitespace-delimited token.

    Blank lines are ignored, sequence lines before the first header are
    discarded, and records with a header but no residues are dropped.

    Raises:
        ValueError: If a header line has no identifier, or if no record
            with a sequence is found.
    """
    proteins = []
    current_id, current_hdr, current_seq = None, None, []

    def flush():
        # Emit the record collected so far (if any and non-empty).
        if current_id is None:
            return
        record = _fasta_record(current_id, current_hdr, current_seq)
        if record is not None:
            proteins.append(record)

    for raw_line in fasta_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if not line.startswith(">"):
            current_seq.append(line)
            continue

        flush()
        current_hdr = line[1:].strip()
        if not current_hdr:
            # Previously this crashed with IndexError on "".split()[0];
            # report it as bad input so callers can answer with a 400.
            raise ValueError("FASTA header line '>' has no identifier.")
        parts = current_hdr.split("|")
        accession = parts[1].strip() if len(parts) >= 3 else ""
        # Fall back to the first token when the accession field is empty.
        current_id = accession or current_hdr.split()[0]
        current_seq = []

    flush()
    if not proteins:
        raise ValueError("No valid protein sequences found in FASTA input.")
    return proteins
110
+
111
+ def _run_prediction(fasta_text, taxon_id_override):
112
+ proteins = parse_fasta(fasta_text)
113
+ if len(proteins) > config.MAX_SEQUENCES:
114
+ raise ValueError(f"Too many sequences. Max: {config.MAX_SEQUENCES}.")
115
+ protein_ids = [p["id"] for p in proteins]
116
+ sequences = [p["seq"] for p in proteins]
117
+ taxon_ids = [taxon_id_override if taxon_id_override is not None
118
+ else p["taxon_id"] for p in proteins]
119
+ log.info("Proteins: %s | Taxon IDs: %s", protein_ids, taxon_ids)
120
+ t0 = time.perf_counter()
121
+ X_esm = embedder.extract(sequences)
122
+ top50 = predictor.get_top50_taxa()
123
+ X_final = embedder.build_features(X_esm, taxon_ids, top50)
124
+ raw_preds = predictor.predict(X_final, protein_ids)
125
+ ia_weights = predictor.get_ia_weights()
126
+ for p in raw_preds:
127
+ p["ia_weight"] = round(float(ia_weights.get(p["go_term"], 0.0)), 4)
128
+ return proteins, raw_preds, ia_weights, round(time.perf_counter() - t0, 2)
129
+
130
+ @app.route("/health", methods=["GET"])
131
+ def health():
132
+ return jsonify({"status":"ok","device":config.DEVICE,"fp16":config.USE_FP16,"version":"2.0.0"})
133
+
134
+ @app.route("/model/info", methods=["GET"])
135
+ def model_info():
136
+ try: stats = predictor.get_model_stats()
137
+ except RuntimeError as e: return jsonify({"error": str(e)}), 503
138
+ return jsonify({"device":config.DEVICE,"fp16":config.USE_FP16,
139
+ "model_name":config.MODEL_NAME,"ontologies":stats,
140
+ "top50_taxa_count":len(predictor.get_top50_taxa()),
141
+ "thresholds":{
142
+ "STRONG": {"min_ia":config.TIER_GOLD_IA, "min_conf":config.TIER_GOLD_CONF},
143
+ "MODERATE": {"min_ia":config.TIER_GOOD_IA, "min_conf":config.TIER_GOOD_CONF},
144
+ "INDICATIVE":{"min_ia":config.TIER_SILVER_IA, "min_conf":config.TIER_SILVER_CONF},
145
+ },"display_limit":flt.TOP_N_DISPLAY})
146
+
147
@app.route("/taxonomy/search", methods=["GET"])
def taxonomy_search():
    """Search species by name.

    Query parameters:
        q: Search text, at least 2 characters (400 otherwise).
        max_results: Optional result cap; defaults to 8 when missing or
            non-numeric and is clamped to the range 1..20.
    """
    q = request.args.get("q", "").strip()
    if len(q) < 2:
        return jsonify({"error": "Query must be at least 2 characters."}), 400
    try:
        max_r = int(request.args.get("max_results", 8))
    except (TypeError, ValueError):
        max_r = 8
    # Keep the cap positive as well as bounded above.
    max_r = max(1, min(max_r, 20))
    return jsonify({"query": q,
                    "results": taxonomy.search_species(q, max_results=max_r)})
154
+
155
@app.route("/taxonomy/verify", methods=["GET"])
def taxonomy_verify():
    """Resolve a taxon ID given in the ``taxon_id`` query parameter.

    Returns 400 when the parameter is missing or is not an integer.
    """
    raw = request.args.get("taxon_id", "")
    if not raw:
        return jsonify({"error": "taxon_id required."}), 400
    try:
        taxon_id = int(raw)
    except ValueError:
        return jsonify({"error": f"Invalid taxon_id: '{raw}'"}), 400
    return jsonify(taxonomy.resolve_taxon(taxon_id, predictor.get_top50_taxa()))
162
+
163
@app.route("/predict", methods=["POST"])
def predict():
    """Predict GO terms for the proteins in a FASTA payload.

    Expects a JSON object with a required ``fasta`` string and an optional
    ``taxon_id`` that overrides the taxon parsed from each FASTA header.
    The response carries per-protein summaries with the displayed
    predictions, plus a ``job_id`` under which the full filtered set is
    stored for CSV download.

    Status codes: 415 for a non-JSON request, 400 for invalid input,
    503 when prediction raises ``RuntimeError``, 500 for anything else.
    """
    if not request.is_json:
        return jsonify({"error": "Content-Type must be application/json."}), 415
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        # Malformed JSON (None) or a non-object payload such as a list.
        body = {}
    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str) or not fasta_text.strip():
        return jsonify({"error": "'fasta' field is required."}), 400
    fasta_text = fasta_text.strip()

    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError):
            return jsonify({"error": "Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(
            fasta_text, taxon_id_override)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    except Exception as e:
        log.exception("Prediction error")
        return jsonify({"error": str(e)}), 500

    # Group raw predictions by the protein they belong to.
    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    predictions, csv_data, total_display, total_all = {}, {}, 0, 0
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        total_display += len(display)
        total_all += len(all_f)
        predictions[pid] = {"taxon_id": prot["taxon_id"],
                            "summary": flt.summarise(display, all_f, pid),
                            "display": display, "total_all": len(all_f)}
        # Only the CSV store keeps the full filtered list.
        csv_data[pid] = {"all": all_f}

    # NOTE(review): millisecond timestamps can collide when two requests
    # finish in the same millisecond; confirm whether that matters here.
    job_id = str(int(time.time() * 1000))
    _store_csv(job_id, csv_data)
    return jsonify({"job_id": job_id,
                    "metadata": {"n_proteins": len(protein_ids), "device": config.DEVICE,
                                 "total_raw_predictions": len(raw_preds),
                                 "total_filtered": total_all,
                                 "total_displayed": total_display,
                                 "display_limit": flt.TOP_N_DISPLAY,
                                 "elapsed_seconds": elapsed},
                    "predictions": predictions})
203
+
204
+ @app.route("/predict/csv", methods=["GET"])
205
+ def download_csv():
206
+ job_id = request.args.get("job_id","").strip()
207
+ if not job_id: return jsonify({"error":"job_id required."}), 400
208
+ job = _csv_store.get(job_id)
209
+ if not job: return jsonify({"error":f"Job '{job_id}' not found."}), 404
210
+ return Response(_make_csv(job["predictions"]), mimetype="text/csv",
211
+ headers={"Content-Disposition":f"attachment; filename=fungo_{job_id}.csv"})
212
+
213
@app.route("/predict/debug", methods=["POST"])
def predict_debug():
    """Run a prediction and also report why each rejected term was dropped.

    Takes the same JSON payload as ``/predict`` (``fasta`` required,
    ``taxon_id`` optional). For every protein the response holds the
    displayed predictions, the full filtered list, the rejected predictions
    with a ``reason``, and the tier thresholds that were applied.

    Status codes: 415 for a non-JSON request, 400 for invalid input,
    503 when prediction raises ``RuntimeError``, 500 for anything else.
    """
    if not request.is_json:
        return jsonify({"error": "Content-Type must be application/json."}), 415
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        # Malformed JSON (None) or a non-object payload such as a list.
        body = {}
    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str) or not fasta_text.strip():
        return jsonify({"error": "'fasta' required."}), 400
    fasta_text = fasta_text.strip()

    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError):
            return jsonify({"error": "Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(
            fasta_text, taxon_id_override)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    except Exception as e:
        log.exception("Debug error")
        return jsonify({"error": str(e)}), 500

    # Group raw predictions by the protein they belong to.
    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    thr = {"STRONG": {"min_ia": config.TIER_GOLD_IA, "min_conf": config.TIER_GOLD_CONF},
           "MODERATE": {"min_ia": config.TIER_GOOD_IA, "min_conf": config.TIER_GOOD_CONF},
           "INDICATIVE": {"min_ia": config.TIER_SILVER_IA, "min_conf": config.TIER_SILVER_CONF}}
    predictions = {}
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        accepted = {p["go_term"] for p in all_f}

        # Explain every raw prediction that did not survive filtering.
        fo = []
        for pred in raw_by_pid[pid]:
            go = pred["go_term"]
            if go in accepted:
                continue
            ia = pred.get("ia_weight", float(ia_weights.get(go, 0.0)))
            conf = pred["confidence"]
            if go in config.BLACKLIST_TERMS:
                reason = "blacklisted"
            elif ia <= config.TIER_SILVER_IA:
                reason = f"ia_too_low (ia={ia:.4f})"
            elif conf < config.TIER_SILVER_CONF:
                reason = f"conf_too_low (conf={conf:.4f})"
            else:
                reason = "below_all_tiers"
            fo.append({"go_term": go, "ontology": pred["ontology"], "confidence": conf,
                       "ia_weight": ia, "reason": reason})
        # Most specific (highest IA) rejected terms first.
        fo.sort(key=lambda x: -x["ia_weight"])

        predictions[pid] = {"taxon_id": prot["taxon_id"],
                            "summary": flt.summarise(display, all_f, pid),
                            "display": display, "all_filtered": all_f,
                            "filtered_out": fo, "thresholds_used": thr}
    return jsonify({"metadata": {"n_proteins": len(protein_ids), "device": config.DEVICE,
                                 "total_raw": len(raw_preds), "elapsed_seconds": elapsed},
                    "predictions": predictions})
261
+
262
+ @app.errorhandler(404)
263
+ def not_found(e): return jsonify({"error":"Not found."}), 404
264
+ @app.errorhandler(413)
265
+ def too_large(e): return jsonify({"error":"Request too large."}), 413
266
+ @app.errorhandler(500)
267
+ def internal(e):
268
+ log.exception("Unhandled error"); return jsonify({"error":"Internal server error."}), 500
269
+
270
if __name__ == "__main__":
    # Startup: prepare directories, verify model files, load models, serve.
    log.info("FunGO v2.0 — HuggingFace Space starting …")
    config.ensure_dirs()
    if not config.validate_paths():
        log.error("Model paths missing!")
        sys.exit(1)
    predictor.load_all()
    # Use the configured port and debug flag instead of hard-coded values.
    # FUNGO_PORT defaults to "7860" and FUNGO_DEBUG to "0" at the top of this
    # file, so the default behaviour is unchanged.
    log.info("Models loaded. Serving on port %d …", config.PORT)
    app.run(host="0.0.0.0", port=config.PORT, debug=config.DEBUG)
config.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # config.py
2
+ """
3
+ FunGO Backend β€” Central Configuration
4
+ ======================================
5
+ ONLY change paths in this file. Nothing else needs editing.
6
+
7
+ How to use:
8
+ - Update PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META to point to your model files
9
+ - Update MODEL_CACHE_DIR to point to your ESM2 weights cache
10
+ - All other settings work as-is
11
+ """
12
+
13
+ import logging
14
+ import os
15
+ from pathlib import Path
16
+ import torch
17
+
18
+ logger = logging.getLogger("config")
19
+
20
+ # ── DEVICE (auto-detected) ────────────────────────────────────
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
+ USE_FP16 = DEVICE == "cuda"
23
+
24
+ # ── MODEL PATHS β€” UPDATE THESE TO MATCH YOUR SYSTEM ──────────
25
+ PKL_DIR = Path(os.environ.get("FUNGO_PKL_DIR", "/mnt/f/research/thesis/pipeline_outputs/models"))
26
+ VOCAB_PKL = Path(os.environ.get("FUNGO_VOCAB_PKL", "/mnt/f/research/thesis/pipeline_outputs/labels/vocabularies.pkl"))
27
+ IA_PKL = Path(os.environ.get("FUNGO_IA_PKL", "/mnt/f/research/thesis/pipeline_outputs/go_data/ia_weights.pkl"))
28
+ FEAT_META = Path(os.environ.get("FUNGO_FEAT_META", "/mnt/f/research/thesis/pipeline_outputs/features/feature_metadata.json"))
29
+
30
+ # ── ESM2 SETTINGS ─────────────────────────────────────────────
31
+ MODEL_CACHE_DIR = Path(os.environ.get("FUNGO_MODEL_CACHE", "/mnt/e/repeat/embeddings/model_cache"))
32
+ MODEL_NAME = "facebook/esm2_t36_3B_UR50D"
33
+ LAYERS_TO_USE = [30, 31, 32, 33, 34, 35]
34
+ MAX_SEQ_LENGTH = 1400
35
+ BATCH_SIZE = 4 if DEVICE == "cpu" else 16
36
+ TRANSFORMERS_OFFLINE = os.environ.get("FUNGO_OFFLINE", "1")
37
+
38
+ # ── EMBEDDING CACHE ───────────────────────────────────────────
39
+ EMB_CACHE_DIR = Path(os.environ.get("FUNGO_EMB_CACHE", "./embedding_cache"))
40
+
41
+ # ── FILTER THRESHOLDS (do not change) ────────────────────────
42
+ BLACKLIST_TERMS = {
43
+ "GO:0003674","GO:0008150","GO:0005575","GO:0005488",
44
+ "GO:0043226","GO:0043229","GO:0043227","GO:0043231",
45
+ "GO:0110165","GO:0005622","GO:0005623","GO:0044464",
46
+ "GO:0043232","GO:0044424","GO:0009987","GO:0065007",
47
+ "GO:0050794","GO:0019222","GO:0060255","GO:0080090",
48
+ "GO:0050789",
49
+ }
50
+
51
+ # Strong Evidence (was GOLD)
52
+ TIER_GOLD_IA = 5.0
53
+ TIER_GOLD_CONF = 0.30
54
+
55
+ # Moderate Evidence (was GOOD)
56
+ TIER_GOOD_IA = 2.0
57
+ TIER_GOOD_CONF = 0.50
58
+
59
+ # Indicative (was SILVER)
60
+ TIER_SILVER_IA = 1.0
61
+ TIER_SILVER_CONF = 0.65
62
+
63
+ # ── NCBI TAXONOMY API ─────────────────────────────────────────
64
+ NCBI_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
65
+ NCBI_SUMMARY_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
66
+ NCBI_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
67
+ NCBI_TOOL = "FunGO"
68
+ NCBI_EMAIL = "fungo@research.com"
69
+
70
+ # ── FLASK ─────────────────────────────────────────────────────
71
+ PORT = int(os.environ.get("FUNGO_PORT", 5000))
72
+ DEBUG = os.environ.get("FUNGO_DEBUG", "0") == "1"
73
+ MAX_SEQUENCES = int(os.environ.get("FUNGO_MAX_SEQ", 10))
74
+
75
+
76
+ # ── Runtime helpers ───────────────────────────────────────────
77
+
78
+ def ensure_dirs():
79
+ """Create required runtime directories. Called once at startup."""
80
+ EMB_CACHE_DIR.mkdir(parents=True, exist_ok=True)
81
+ logger.info("[config] EMB_CACHE_DIR ready β†’ %s", EMB_CACHE_DIR)
82
+
83
+
84
+ def validate_paths() -> bool:
85
+ """
86
+ Check that all required model files exist.
87
+ Returns True if all found, False if any missing.
88
+ Called at startup before loading models.
89
+ """
90
+ required = {
91
+ "PKL_DIR": PKL_DIR,
92
+ "VOCAB_PKL": VOCAB_PKL,
93
+ "IA_PKL": IA_PKL,
94
+ "FEAT_META": FEAT_META,
95
+ "MODEL_CACHE_DIR": MODEL_CACHE_DIR,
96
+ }
97
+ all_ok = True
98
+ for name, path in required.items():
99
+ if path.exists():
100
+ logger.info("[config] βœ“ %-18s β†’ %s", name, path)
101
+ else:
102
+ logger.error("[config] βœ— %-18s β†’ %s (NOT FOUND)", name, path)
103
+ all_ok = False
104
+ return all_ok
embedder.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # embedder.py
2
+ """
3
+ FunGO β€” ESM2 Embedding Extractor
4
+ ==================================
5
+ Extracts layers 30–35 from ESM2-t36-3B.
6
+ - Auto-detects CPU vs GPU
7
+ - Caches embeddings per session to avoid re-extraction
8
+ - Lazy model loading (loaded only on first request)
9
+ """
10
+
11
+ import os
12
+ import hashlib
13
+ import numpy as np
14
+ import torch
15
+ from pathlib import Path
16
+ from config import (
17
+ MODEL_CACHE_DIR, MODEL_NAME, LAYERS_TO_USE,
18
+ MAX_SEQ_LENGTH, BATCH_SIZE, DEVICE, USE_FP16,
19
+ EMB_CACHE_DIR,
20
+ )
21
+
22
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
23
+ os.environ["HF_DATASETS_OFFLINE"] = "1"
24
+ os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE_DIR)
25
+ os.environ["HF_HOME"] = str(MODEL_CACHE_DIR)
26
+
27
+ N_ESM_DIMS = len(LAYERS_TO_USE) * 2560 # 6 Γ— 2560 = 15,360
28
+
29
+ # ── Lazy globals ──────────────────────────────────────────────────────────
30
+ _tokenizer = None
31
+ _model = None
32
+
33
+
34
+ def _load_model():
35
+ """Load ESM2 tokenizer and model (only once)."""
36
+ global _tokenizer, _model
37
+
38
+ if _tokenizer is not None and _model is not None:
39
+ return _tokenizer, _model
40
+
41
+ print(f"[embedder] Loading ESM2 from local cache β†’ {MODEL_CACHE_DIR}")
42
+ print(f"[embedder] Device: {DEVICE} | FP16: {USE_FP16}")
43
+
44
+ from transformers import EsmTokenizer, EsmModel
45
+
46
+ _tokenizer = EsmTokenizer.from_pretrained(
47
+ MODEL_NAME,
48
+ cache_dir=MODEL_CACHE_DIR,
49
+ local_files_only=True,
50
+ )
51
+ _model = EsmModel.from_pretrained(
52
+ MODEL_NAME,
53
+ cache_dir=MODEL_CACHE_DIR,
54
+ output_hidden_states=True,
55
+ local_files_only=True,
56
+ )
57
+
58
+ if USE_FP16:
59
+ _model = _model.to(DEVICE).half()
60
+ else:
61
+ _model = _model.to(DEVICE)
62
+
63
+ _model.eval()
64
+ for p in _model.parameters():
65
+ p.requires_grad = False
66
+
67
+ print(f"[embedder] Model ready on {DEVICE}")
68
+ return _tokenizer, _model
69
+
70
+
71
def _seq_cache_key(sequences: list) -> str:
    """Return a 16-hex-character cache key for an ordered batch of sequences.

    The key covers every residue of every sequence. Hashing only a prefix
    plus the length would let two different sequences that share a prefix
    and a length map to the same key, so one would be served the other's
    cached embeddings.

    Keys differ from those of the earlier prefix-based scheme, so cache
    files written under that scheme are no longer matched.
    """
    digest = hashlib.sha256()
    for seq in sequences:
        # Length-prefix each sequence so ["AB", "C"] and ["A", "BC"] differ.
        digest.update(f"{len(seq)}:".encode())
        digest.update(seq.encode())
    return digest.hexdigest()[:16]
75
+
76
+
77
def _load_cache(key: str):
    """Return the cached embedding array for *key*, or None on a cache miss.

    An unreadable or corrupt cache file is treated as a miss, so the caller
    re-extracts instead of failing the whole request.
    """
    path = EMB_CACHE_DIR / f"{key}.npy"
    if not path.exists():
        return None
    try:
        return np.load(str(path))
    except (OSError, ValueError, EOFError) as exc:
        print(f"[embedder] Ignoring unreadable cache file {path}: {exc}")
        return None
82
+
83
+
84
def _save_cache(key: str, arr: np.ndarray):
    """Write *arr* to ``<EMB_CACHE_DIR>/<key>.npy``."""
    # Create the directory on demand so saving does not depend on a separate
    # startup step having run first.
    EMB_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    np.save(str(EMB_CACHE_DIR / f"{key}.npy"), arr)
86
+
87
+
88
+ def extract(sequences: list) -> np.ndarray:
89
+ """
90
+ Extract ESM2 embeddings for a list of sequences.
91
+ Returns np.ndarray of shape (N, 15360), dtype float32.
92
+ Sequences are truncated to MAX_SEQ_LENGTH if needed.
93
+ Uses cache to avoid re-extraction.
94
+ """
95
+ # Truncate sequences
96
+ seqs_truncated = [s[:MAX_SEQ_LENGTH] for s in sequences]
97
+ N = len(seqs_truncated)
98
+
99
+ # Check cache
100
+ cache_key = _seq_cache_key(seqs_truncated)
101
+ cached_emb = _load_cache(cache_key)
102
+ if cached_emb is not None and cached_emb.shape == (N, N_ESM_DIMS):
103
+ print(f"[embedder] Cache hit β€” skipping extraction for {N} sequences")
104
+ return cached_emb.astype(np.float32)
105
+
106
+ print(f"[embedder] Extracting embeddings: {N} sequences on {DEVICE}")
107
+
108
+ tokenizer, model = _load_model()
109
+
110
+ X = np.zeros((N, N_ESM_DIMS), dtype=np.float32)
111
+ current_batch = BATCH_SIZE
112
+
113
+ with torch.no_grad():
114
+ i = 0
115
+ while i < N:
116
+ batch_end = min(i + current_batch, N)
117
+ batch_seqs = seqs_truncated[i:batch_end]
118
+
119
+ try:
120
+ inputs = tokenizer(
121
+ batch_seqs,
122
+ return_tensors="pt",
123
+ padding=True,
124
+ truncation=True,
125
+ max_length=MAX_SEQ_LENGTH + 2,
126
+ )
127
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
128
+
129
+ outputs = model(**inputs)
130
+ hidden_states = outputs.hidden_states
131
+
132
+ for j, seq in enumerate(batch_seqs):
133
+ seq_len = len(seq)
134
+ layer_vecs = []
135
+
136
+ for layer_idx in LAYERS_TO_USE:
137
+ h = hidden_states[layer_idx][j, 1:seq_len + 1, :]
138
+ v = h.mean(dim=0)
139
+ if DEVICE == "cuda":
140
+ v = v.float().cpu().numpy()
141
+ else:
142
+ v = v.numpy()
143
+ layer_vecs.append(v)
144
+
145
+ X[i + j] = np.concatenate(layer_vecs)
146
+
147
+ i += len(batch_seqs)
148
+ print(f"[embedder] {i}/{N} done")
149
+
150
+ except RuntimeError as e:
151
+ if "out of memory" in str(e).lower() and current_batch > 1:
152
+ current_batch = max(1, current_batch // 2)
153
+ print(f"[embedder] OOM β€” batch size reduced to {current_batch}")
154
+ if DEVICE == "cuda":
155
+ torch.cuda.empty_cache()
156
+ else:
157
+ raise
158
+
159
+ # Sanitise
160
+ bad = np.isnan(X).sum() + np.isinf(X).sum()
161
+ if bad > 0:
162
+ X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
163
+
164
+ # Save cache
165
+ _save_cache(cache_key, X)
166
+ print(f"[embedder] Saved to cache: {cache_key}")
167
+
168
+ return X
169
+
170
+
171
def build_features(X_esm: np.ndarray, taxon_ids: list,
                   top50_taxa: list) -> np.ndarray:
    """Append a 51-column one-hot taxonomy block to the ESM embeddings.

    Columns 0..49 mark the position of the row's taxon in *top50_taxa*;
    column 50 is the "unknown species" flag, set when the taxon is None or
    is not one of the first 50 listed taxa. Every row gets exactly one 1.0.

    Args:
        X_esm: Embedding matrix with one row per protein.
        taxon_ids: One taxon ID (or None) per row of *X_esm*.
        top50_taxa: Known taxa; list position is the feature column.

    Returns:
        ``X_esm`` with the 51 taxonomy columns stacked on the right.

    Raises:
        ValueError: If *taxon_ids* does not have one entry per row.
    """
    n_rows = X_esm.shape[0]
    if len(taxon_ids) != n_rows:
        raise ValueError(
            f"Expected {n_rows} taxon IDs (one per sequence), got {len(taxon_ids)}."
        )

    unknown_col = 50
    taxon_to_col = {t: i for i, t in enumerate(top50_taxa)}
    X_tax = np.zeros((n_rows, unknown_col + 1), dtype=np.float32)

    for row, tx in enumerate(taxon_ids):
        col = taxon_to_col.get(tx) if tx is not None else None
        # Entries past the 50 one-hot columns have no feature of their own;
        # flag them as unknown rather than indexing outside the block.
        if col is None or col >= unknown_col:
            col = unknown_col
        X_tax[row, col] = 1.0

    return np.hstack([X_esm, X_tax])
filter.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # filter.py
2
+ """
3
+ FunGO β€” Smart Tier Filtering
4
+ ==============================
5
+ Removes generic/root GO terms and assigns evidence tiers
6
+ to remaining predictions.
7
+
8
+ Changes from original:
9
+ 1. Tier names updated:
10
+ GOLD β†’ STRONG (Strong Evidence)
11
+ GOOD β†’ MODERATE (Moderate Evidence)
12
+ SILVER β†’ INDICATIVE
13
+ 2. Combined score = ia_weight Γ— confidence
14
+ Used for ranking β€” more scientifically sound.
15
+ 3. filter_predictions() returns a dict with two keys:
16
+ "display" β€” top 20 by combined score (for UI screen)
17
+ "all" β€” full filtered list (for CSV download)
18
+ 4. summarise() updated to use new tier keys.
19
+ 5. Blacklist + IA/confidence thresholds β†’ completely unchanged.
20
+ """
21
+
22
+ import logging
23
+ from config import (
24
+ BLACKLIST_TERMS,
25
+ TIER_GOLD_IA, TIER_GOLD_CONF,
26
+ TIER_GOOD_IA, TIER_GOOD_CONF,
27
+ TIER_SILVER_IA, TIER_SILVER_CONF,
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ ONT_LABELS = {
33
+ "MFO": "Molecular Function",
34
+ "BPO": "Biological Process",
35
+ "CCO": "Cellular Component",
36
+ }
37
+
38
+ TIER_LABELS = {
39
+ "STRONG": "Strong Evidence",
40
+ "MODERATE": "Moderate Evidence",
41
+ "INDICATIVE": "Indicative",
42
+ }
43
+
44
+ TIER_RANK = {"STRONG": 0, "MODERATE": 1, "INDICATIVE": 2}
45
+
46
+ # Max predictions shown on screen per protein
47
+ TOP_N_DISPLAY = 20
48
+
49
+
50
+ def assign_tier(go_term: str, ia: float, confidence: float) -> str:
51
+ """
52
+ Assign evidence tier. Thresholds unchanged from original.
53
+
54
+ Returns: "STRONG" | "MODERATE" | "INDICATIVE" | "NOISE"
55
+ """
56
+ if go_term in BLACKLIST_TERMS:
57
+ return "NOISE"
58
+ if ia > TIER_GOLD_IA and confidence >= TIER_GOLD_CONF:
59
+ return "STRONG"
60
+ if ia > TIER_GOOD_IA and confidence >= TIER_GOOD_CONF:
61
+ return "MODERATE"
62
+ if ia > TIER_SILVER_IA and confidence >= TIER_SILVER_CONF:
63
+ return "INDICATIVE"
64
+ return "NOISE"
65
+
66
+
67
+ def combined_score(ia: float, confidence: float) -> float:
68
+ """
69
+ Ranking score = ia_weight Γ— confidence.
70
+ Balances specificity (IA) and model certainty (confidence).
71
+ """
72
+ return round(ia * confidence, 6)
73
+
74
+
75
+ def filter_predictions(raw_predictions: list, ia_weights: dict) -> dict:
76
+ """
77
+ Filter raw predictions and return display + full sets.
78
+
79
+ Returns
80
+ -------
81
+ {
82
+ "display": top-20 predictions (sorted by combined_score desc),
83
+ "all": all filtered predictions (for CSV)
84
+ }
85
+
86
+ Each prediction dict contains:
87
+ go_term, ontology, ontology_label, confidence, threshold,
88
+ ia_weight, combined_score, tier, tier_rank, tier_label
89
+ """
90
+ filtered = []
91
+
92
+ for pred in raw_predictions:
93
+ go_term = pred["go_term"]
94
+ confidence = pred["confidence"]
95
+ ia = float(ia_weights.get(go_term, 0.0))
96
+ tier = assign_tier(go_term, ia, confidence)
97
+
98
+ if tier == "NOISE":
99
+ continue
100
+
101
+ if tier not in TIER_RANK:
102
+ logger.warning("Unknown tier %r for %s β€” skipping", tier, go_term)
103
+ continue
104
+
105
+ score = combined_score(ia, confidence)
106
+
107
+ filtered.append({
108
+ **pred,
109
+ "ia_weight": round(ia, 4),
110
+ "combined_score": score,
111
+ "tier": tier,
112
+ "tier_rank": TIER_RANK[tier],
113
+ "tier_label": TIER_LABELS[tier],
114
+ "ontology_label": ONT_LABELS.get(pred["ontology"], pred["ontology"]),
115
+ })
116
+
117
+ # Sort by combined score descending, tier_rank as tiebreaker
118
+ filtered.sort(key=lambda x: (-x["combined_score"], x["tier_rank"]))
119
+
120
+ return {
121
+ "display": filtered[:TOP_N_DISPLAY],
122
+ "all": filtered,
123
+ }
124
+
125
+
126
+ def summarise(filtered_display: list, all_filtered: list, protein_id: str) -> dict:
127
+ """
128
+ Per-protein summary. Counts are over ALL filtered (not just top-20).
129
+ """
130
+ ont_counts = {"MFO": 0, "BPO": 0, "CCO": 0}
131
+ tier_counts = {"STRONG": 0, "MODERATE": 0, "INDICATIVE": 0}
132
+
133
+ for p in all_filtered:
134
+ ont = p.get("ontology", "")
135
+ if ont in ont_counts:
136
+ ont_counts[ont] += 1
137
+ t = p.get("tier", "")
138
+ if t in tier_counts:
139
+ tier_counts[t] += 1
140
+
141
+ n = len(all_filtered)
142
+ return {
143
+ "protein_id": protein_id,
144
+ "total_filtered": n,
145
+ "displayed": len(filtered_display),
146
+ "by_ontology": ont_counts,
147
+ "by_tier": tier_counts,
148
+ "has_strong_evidence": tier_counts["STRONG"] > 0,
149
+ "avg_confidence": round(sum(p["confidence"] for p in all_filtered) / n, 4) if n else 0.0,
150
+ "avg_ia": round(sum(p["ia_weight"] for p in all_filtered) / n, 4) if n else 0.0,
151
+ "avg_combined_score": round(sum(p["combined_score"] for p in all_filtered) / n, 4) if n else 0.0,
152
+ }
hf_README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FunGO
3
+ emoji: 🧬
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Protein Function Prediction using ESM2 + XGBoost
12
+ ---
13
+
14
+ # FunGO — Protein Function Prediction
15
+
16
+ **Beyond Prediction — Understanding Function.**
17
+
18
+ FunGO predicts Gene Ontology (GO) terms for protein sequences using:
19
+ - **ESM2-t36-3B** — protein language model embeddings (layers 30–35)
20
+ - **XGBoost classifiers** — 4,133 GO-term specific models
21
+ - **Evidence-tiered filtering** — Strong / Moderate / Indicative
22
+
23
+ ## Evidence Tiers
24
+
25
+ | Tier | IA Weight | Confidence | Description |
26
+ |------|-----------|------------|-------------|
27
+ | Strong Evidence | > 5.0 | ≥ 0.30 | Highly specific GO term |
27
+ | Moderate Evidence | > 2.0 | ≥ 0.50 | Moderately specific term |
28
+ | Indicative | > 1.0 | ≥ 0.65 | Lower specificity, high confidence |
30
+
31
+ ## Ontologies Covered
32
+
33
+ - **MFO** — Molecular Function
34
+ - **BPO** — Biological Process
35
+ - **CCO** — Cellular Component
36
+
37
+ ## Development Team
38
+
39
+ - **Dr. Beenish Maqsood** — Principal Investigator, School of Biochemistry and Biotechnology, University of the Punjab
40
+ - **Dr. Naeem Mahmood** — Co-Supervisor, School of Biochemistry and Biotechnology, University of the Punjab
41
+ - **Muteeba Azhar** — Lead Developer, MS Researcher, University of the Punjab
predictor.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # predictor.py
2
+ """
3
+ FunGO β€” Prediction Engine
4
+ ===========================
5
+ Loads XGBoost models once at startup.
6
+ Runs inference across all 3 ontologies (MFO, BPO, CCO).
7
+
8
+ Changes from original:
9
+ 1. Added get_model_stats() β€” returns classifier counts per ontology
10
+ (used by /model/info endpoint).
11
+ 2. Fixed open() to use context managers (file handles now closed).
12
+ 3. tempfile.mktemp() replaced with NamedTemporaryFile (WSL fix).
13
+ 4. Failed classifiers are counted and logged instead of silent pass.
14
+ 5. Input shape validation in predict().
15
+ """
16
+
17
+ import json
18
+ import logging
19
+ import pickle
20
+ import shutil
21
+ import subprocess
22
+ import tempfile
23
+ import numpy as np
24
+ from pathlib import Path
25
+
26
+ from config import PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META
27
+
28
+ logger = logging.getLogger(__name__)
29
+ ONTS = ["MFO", "BPO", "CCO"]
30
+
31
+ # ── Globals ───────────────────────────────────────────────────
32
+ _models_dict = None
33
+ _thresholds_dict = None
34
+ _ia_weights = None
35
+ _vocabularies = None
36
+ _top50_taxa = None
37
+
38
+
39
+ # ── Helpers ───────────────────────────────────────────────────
40
+
41
+ def _wsl_copy(src: Path) -> Path:
42
+ """Copy file to temp path (WSL mounted-drive permission workaround)."""
43
+ with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
44
+ tmp_path = Path(tmp.name)
45
+ shutil.copy2(str(src), str(tmp_path))
46
+ return tmp_path
47
+
48
+
49
+ def _safe_load(path: Path) -> object:
50
+ """Load pickle with WSL permission workaround if needed."""
51
+ try:
52
+ subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
53
+ except Exception:
54
+ pass
55
+ try:
56
+ with open(path, "rb") as fh:
57
+ return pickle.load(fh)
58
+ except PermissionError:
59
+ pass
60
+ tmp_path = None
61
+ try:
62
+ tmp_path = _wsl_copy(path)
63
+ with open(tmp_path, "rb") as fh:
64
+ return pickle.load(fh)
65
+ finally:
66
+ if tmp_path and tmp_path.exists():
67
+ tmp_path.unlink()
68
+
69
+
70
+ def _safe_read_json(path: Path) -> dict:
71
+ """Read JSON with WSL permission workaround."""
72
+ try:
73
+ subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
74
+ except Exception:
75
+ pass
76
+ for mode in ("r", "rb"):
77
+ try:
78
+ with open(path, mode) as fh:
79
+ raw = fh.read()
80
+ if isinstance(raw, bytes):
81
+ raw = raw.decode("utf-8", errors="replace")
82
+ return json.loads(raw)
83
+ except PermissionError:
84
+ continue
85
+ result = subprocess.run(["cat", str(path)], capture_output=True, text=True, check=True)
86
+ return json.loads(result.stdout)
87
+
88
+
89
+ # ── Public API ────────────────────────────────────────────────
90
+
91
def _unpack_model_bundle(raw, ont: str) -> tuple:
    """Normalise one ontology pickle into ``(models, thresholds)`` dicts keyed by GO term.

    Two layouts are accepted:
      * a flat ``{GO term: classifier}`` dict (every threshold defaults to 0.5);
      * a bundle dict with ``"models"``, ``"selected_terms"`` and an optional
        ``"thresholds"`` entry (a sequence aligned with the terms, or a dict
        keyed by term).
    """
    if not raw:
        # An empty pickle used to crash with StopIteration on next(iter(raw)).
        logger.warning("[predictor] %s: model file contains no classifiers", ont)
        return {}, {}

    first_key = next(iter(raw))
    if isinstance(first_key, str) and first_key.startswith("GO:"):
        return raw, {term: 0.5 for term in raw}

    clf_list = raw["models"]
    term_list = raw["selected_terms"]
    thr_raw = raw.get("thresholds", [0.5] * len(clf_list))
    if isinstance(thr_raw, dict):
        thr_list = [thr_raw.get(term, 0.5) for term in term_list]
    else:
        thr_list = list(thr_raw)

    # zip() truncates silently; make a mismatch visible in the logs.
    if len(clf_list) != len(term_list):
        logger.warning(
            "[predictor] %s: %d classifiers but %d terms; extra entries are ignored",
            ont, len(clf_list), len(term_list),
        )
    if len(thr_list) != len(term_list):
        logger.warning(
            "[predictor] %s: %d thresholds for %d terms; missing ones default to 0.5 at predict time",
            ont, len(thr_list), len(term_list),
        )

    return dict(zip(term_list, clf_list)), dict(zip(term_list, thr_list))


def load_all():
    """
    Load all models and supporting data into memory.
    Call once at Flask startup (~30–120 s depending on hardware, per the
    original note).

    Everything is loaded into locals first and published to the module
    globals only after every file has loaded, so a failure part-way through
    leaves the module in the "not loaded" state rather than half-populated.

    Raises whatever the underlying file/pickle/JSON loading raises
    (e.g. FileNotFoundError for a missing model file, KeyError for a
    malformed bundle or metadata file).
    """
    global _models_dict, _thresholds_dict, _ia_weights, _vocabularies, _top50_taxa

    logger.info("[predictor] Loading vocabularies …")
    vocabularies = _safe_load(VOCAB_PKL)

    logger.info("[predictor] Loading IA weights …")
    ia_weights = _safe_load(IA_PKL)
    logger.info("[predictor] IA weights: %d terms", len(ia_weights))

    meta = _safe_read_json(FEAT_META)
    top50_taxa = [int(t) for t in meta["taxonomy_info"]["top50_taxa"]]
    logger.info("[predictor] Top-50 taxa loaded (%d)", len(top50_taxa))

    models = {}
    thresholds = {}

    for ont in ONTS:
        pkl_path = PKL_DIR / f"models_{ont}.pkl"
        size_mb = pkl_path.stat().st_size / 1e6
        logger.info("[predictor] Loading %s (%.0f MB) …", pkl_path.name, size_mb)

        models[ont], thresholds[ont] = _unpack_model_bundle(_safe_load(pkl_path), ont)
        logger.info("[predictor] %s: %d classifiers ready", ont, len(models[ont]))

    # Publish atomically; _models_dict goes last because predict() and
    # get_model_stats() use it as the "loaded" flag.
    _vocabularies = vocabularies
    _ia_weights = ia_weights
    _top50_taxa = top50_taxa
    _thresholds_dict = thresholds
    _models_dict = models

    logger.info("[predictor] All models loaded successfully.")
137
+
138
+
139
def get_top50_taxa() -> list:
    """Return the taxon-ID list stored by load_all().

    Raises:
        RuntimeError: if load_all() has not populated it yet.
    """
    taxa = _top50_taxa
    if taxa is not None:
        return taxa
    raise RuntimeError("Models not loaded β€” call load_all() first.")
143
+
144
+
145
def get_ia_weights() -> dict:
    """Return the IA-weight mapping stored by load_all().

    Raises:
        RuntimeError: if load_all() has not populated it yet.
    """
    weights = _ia_weights
    if weights is not None:
        return weights
    raise RuntimeError("Models not loaded β€” call load_all() first.")
149
+
150
+
151
def get_model_stats() -> dict:
    """
    Count the loaded classifiers for each ontology.

    Per the original note this backs the GET /model/info endpoint.
    Example result: {"MFO": 1500, "BPO": 1500, "CCO": 1133}

    Raises:
        RuntimeError: if load_all() has not run yet.
    """
    loaded = _models_dict
    if loaded is None:
        raise RuntimeError("Models not loaded β€” call load_all() first.")

    counts = {}
    for ont, classifiers in loaded.items():
        counts[ont] = len(classifiers)
    return counts
160
+
161
+
162
def predict(X_final: np.ndarray, protein_ids: list) -> list:
    """
    Run inference for all proteins across all 3 ontologies.

    Parameters
    ----------
    X_final     : 2-D feature matrix, one row per protein (documented by the
                  original author as (N, 15411) float32)
    protein_ids : list of N protein ID strings, aligned with the rows

    Returns
    -------
    List of raw prediction dicts, one per (protein, GO term) whose
    confidence reaches that term's threshold:
        [{protein_id, go_term, ontology, confidence, threshold}, …]
    Confidence and threshold are rounded to 4 decimals. An empty input
    returns an empty list without running any classifier.

    Raises
    ------
    RuntimeError : load_all() has not been called.
    ValueError   : X_final is not 2-D, or its row count differs from
                   len(protein_ids).

    A classifier that raises (or returns the wrong number of scores) is
    logged, counted and skipped; it never contributes partial results.
    """
    if _models_dict is None:
        raise RuntimeError("Models not loaded β€” call load_all() first.")

    # Reject malformed input up front; otherwise every classifier would fail
    # individually and the caller would just see an empty result.
    if getattr(X_final, "ndim", None) != 2:
        raise ValueError(
            f"X_final must be a 2-D (N, n_features) matrix; "
            f"got shape {getattr(X_final, 'shape', None)}."
        )

    N = X_final.shape[0]
    if N != len(protein_ids):
        raise ValueError(
            f"X_final has {N} rows but protein_ids has {len(protein_ids)} entries."
        )
    if N == 0:
        return []

    all_preds = []
    failed_terms = 0

    for ont in ONTS:
        ont_models = _models_dict[ont]
        ont_thresholds = _thresholds_dict[ont]
        n_terms = len(ont_models)
        logger.info("[predictor] %s β€” scoring %d terms Γ— %d proteins …", ont, n_terms, N)

        for go_term, clf in ont_models.items():
            threshold = float(ont_thresholds.get(go_term, 0.5))
            try:
                # Compare in float64 so the cut-off matches a per-element
                # float(proba[i]) >= threshold test exactly.
                scores = np.asarray(clf.predict_proba(X_final)[:, 1], dtype=np.float64)
                if scores.shape != (N,):
                    raise ValueError(
                        f"expected {N} scores, got shape {scores.shape}"
                    )
                hits = np.flatnonzero(scores >= threshold)
            except Exception as exc:
                failed_terms += 1
                logger.warning("[predictor] Classifier failed %s/%s: %s", ont, go_term, exc)
                continue

            shown_threshold = round(threshold, 4)
            all_preds.extend(
                {
                    "protein_id": protein_ids[int(i)],
                    "go_term": go_term,
                    "ontology": ont,
                    "confidence": round(float(scores[i]), 4),
                    "threshold": shown_threshold,
                }
                for i in hits
            )

    if failed_terms:
        logger.warning("[predictor] Total failed classifiers: %d", failed_terms)

    logger.info("[predictor] Inference complete β€” %d raw predictions", len(all_preds))
    return all_preds
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flask>=3.0.0
2
+ flask-cors>=4.0.0
3
+ numpy>=1.24.0
4
+ torch>=2.0.0
5
+ transformers>=4.35.0
6
+ xgboost>=2.0.0
7
+ requests>=2.31.0
8
+ gunicorn>=21.0.0
9
+ gradio>=4.0.0
taxonomy.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # taxonomy.py
2
+ """
3
+ FunGO β€” NCBI Taxonomy Service
4
+ ===============================
5
+ Species name β†’ taxon ID lookup and reverse lookup.
6
+
7
+ Fixes applied:
8
+ 1. UID string/int consistency β€” result_map keys are always strings,
9
+ now explicitly uses str(uid) so 9606 never resolves to {}.
10
+ 2. Species-rank preference β€” results sorted so "species" rank
11
+ appears before "genus". Prevents 9605 (Homo genus) showing
12
+ before 9606 (Homo sapiens species).
13
+ 3. Exact-name boost β€” exact query match moved to position 0.
14
+ 4. Cache key includes max_results to prevent stale smaller lists.
15
+ 5. xml.etree.ElementTree replaces fragile regex XML parsing.
16
+ 6. Retry logic β€” 3 attempts with 2s gap on connection errors.
17
+ """
18
+
19
+ import logging
20
+ import time
21
+ import xml.etree.ElementTree as ET
22
+ import requests
23
+
24
+ from config import (
25
+ NCBI_SEARCH_URL, NCBI_SUMMARY_URL, NCBI_FETCH_URL,
26
+ NCBI_TOOL, NCBI_EMAIL,
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+ HEADERS = {"User-Agent": f"FunGO/1.0 ({NCBI_EMAIL})"}
31
+ TIMEOUT = 10
32
+ RETRIES = 3
33
+ RETRY_DELAY = 2
34
+
35
+ _RANK_PRIORITY = {
36
+ "species": 0, "subspecies": 1, "varietas": 2,
37
+ "forma": 3, "strain": 4, "no rank": 5,
38
+ "genus": 6, "family": 7, "order": 8,
39
+ "class": 9, "phylum": 10, "kingdom": 11, "superkingdom": 12,
40
+ }
41
+
42
+ def _rank_priority(rank: str) -> int:
43
+ return _RANK_PRIORITY.get(rank.lower().strip(), 99)
44
+
45
+ _search_cache: dict = {}
46
+ _id_to_info_cache: dict = {}
47
+
48
+
49
def _ncbi_get(url: str, params: dict) -> requests.Response:
    """GET ``url`` with the module's timeout and headers, retrying on failure.

    Up to RETRIES attempts are made. Any ``requests.RequestException``
    (including the HTTP error raised by ``raise_for_status``) triggers a
    warning and a RETRY_DELAY-second pause before the next attempt; after the
    final attempt the last exception is re-raised.
    """
    attempt = 0
    failure = None
    while attempt < RETRIES:
        attempt += 1
        try:
            response = requests.get(url, params=params, timeout=TIMEOUT, headers=HEADERS)
            response.raise_for_status()
        except requests.RequestException as exc:
            failure = exc
            if attempt >= RETRIES:
                # Out of attempts: no log line, no sleep, just fall through.
                break
            logger.warning("[taxonomy] Request error (attempt %d/%d): %s β€” retrying in %ds",
                           attempt, RETRIES, exc, RETRY_DELAY)
            time.sleep(RETRY_DELAY)
        else:
            return response
    raise failure
63
+
64
+
65
def search_species(query: str, max_results: int = 8) -> list:
    """
    Look up taxa on NCBI by (partial) name.

    Returns a list of {taxon_id, scientific_name, common_name, rank, division}
    dicts, ordered by rank priority with exact scientific-name matches moved
    to the front. Queries shorter than two characters (after stripping)
    return [] without contacting NCBI. Successful lookups are memoised per
    (lower-cased query, max_results). On any failure the result is a
    one-element list ``[{"error": <message>}]`` and nothing is cached.
    """
    query = query.strip()
    if len(query) < 2:
        return []

    needle = query.lower()
    cache_key = (needle, max_results)
    remembered = _search_cache.get(cache_key)
    if remembered is not None:
        return remembered

    try:
        # Step 1: esearch -> matching taxon IDs.
        found = _ncbi_get(NCBI_SEARCH_URL, {
            "db": "taxonomy", "term": query,
            "retmax": max_results, "retmode": "json",
            "tool": NCBI_TOOL, "email": NCBI_EMAIL,
        }).json()
        ids = found.get("esearchresult", {}).get("idlist", [])

        if not ids:
            _search_cache[cache_key] = []
            return []

        # Step 2: esummary -> details for those IDs.
        summary = _ncbi_get(NCBI_SUMMARY_URL, {
            "db": "taxonomy", "id": ",".join(ids),
            "retmode": "json", "tool": NCBI_TOOL, "email": NCBI_EMAIL,
        }).json().get("result", {})

        hits = []
        for uid in summary.get("uids", ids):
            # Summary entries are keyed by string UID.
            entry = summary.get(str(uid), {})
            if entry:
                hits.append({
                    "taxon_id": int(uid),
                    "scientific_name": entry.get("scientificname", ""),
                    "common_name": entry.get("commonname", ""),
                    "rank": entry.get("rank", ""),
                    "division": entry.get("division", ""),
                })

        # Rank ordering first (stable), then float exact-name matches to the top.
        ranked = sorted(hits, key=lambda hit: _rank_priority(hit.get("rank", "")))
        exact, rest = [], []
        for hit in ranked:
            bucket = exact if hit["scientific_name"].lower() == needle else rest
            bucket.append(hit)
        results = exact + rest

        _search_cache[cache_key] = results
        logger.info("[taxonomy] search %r β†’ %d results", query, len(results))
        return results

    except Exception as exc:
        logger.error("[taxonomy] search_species(%r) failed: %s", query, exc)
        return [{"error": str(exc)}]
127
+
128
+
129
def get_taxon_info(taxon_id: int) -> dict:
    """
    Resolve a taxon ID to its NCBI record (names, rank, division, lineage).

    The efetch XML response is parsed with xml.etree.ElementTree. A
    successful lookup is cached by taxon ID and returned with
    ``verified=True``. On failure (request error, XML parse error, or no
    ``Taxon`` element in the response) the result is the blank record with
    ``verified=False`` plus an ``"error"`` message, and nothing is cached.
    """
    try:
        return _id_to_info_cache[taxon_id]
    except KeyError:
        pass

    base = {
        "taxon_id": taxon_id, "scientific_name": "",
        "common_name": "", "rank": "", "division": "",
        "lineage": "", "verified": False,
    }

    try:
        resp = _ncbi_get(NCBI_FETCH_URL, {
            "db": "taxonomy", "id": taxon_id,
            "retmode": "xml", "tool": NCBI_TOOL, "email": NCBI_EMAIL,
        })

        taxon_el = ET.fromstring(resp.text).find("Taxon")
        if taxon_el is None:
            base["error"] = "Taxon element not found in NCBI XML"
            return base

        # Ancestor names in document order, blanks dropped.
        ancestor_names = (
            (node.findtext("ScientificName") or "").strip()
            for node in taxon_el.findall("./LineageEx/Taxon")
        )
        lineage = " > ".join(name for name in ancestor_names if name)

        # Prefer the nested OtherNames/CommonName, then a direct CommonName child.
        common = (taxon_el.findtext("OtherNames/CommonName") or
                  taxon_el.findtext("CommonName") or "")

        info = dict(base)
        info.update({
            "scientific_name": (taxon_el.findtext("ScientificName") or "").strip(),
            "common_name": common.strip(),
            "rank": (taxon_el.findtext("Rank") or "").strip(),
            "division": (taxon_el.findtext("Division") or "").strip(),
            "lineage": lineage,
            "verified": True,
        })

        _id_to_info_cache[taxon_id] = info
        logger.info("[taxonomy] Resolved taxon %d β†’ %s", taxon_id, info["scientific_name"])
        return info

    except ET.ParseError as exc:
        logger.error("[taxonomy] XML parse error for taxon %d: %s", taxon_id, exc)
        base["error"] = f"XML parse error: {exc}"
        return base
    except Exception as exc:
        logger.error("[taxonomy] get_taxon_info(%d) failed: %s", taxon_id, exc)
        base["error"] = str(exc)
        return base
190
+
191
+
192
def resolve_taxon(taxon_id: int, top50_taxa: list) -> dict:
    """Return the taxon record from get_taxon_info() plus training-set membership.

    Adds ``in_training`` (whether ``taxon_id`` is in ``top50_taxa``) and
    ``training_status`` (``"in_training_data"`` or
    ``"unknown_species_fallback"``) to a copy of the record.
    """
    record = dict(get_taxon_info(taxon_id))
    seen_in_training = taxon_id in top50_taxa
    record["in_training"] = seen_in_training
    if seen_in_training:
        record["training_status"] = "in_training_data"
    else:
        record["training_status"] = "unknown_species_fallback"
    return record