FunGO v2.0 backend
Browse files- app.py +278 -0
- config.py +104 -0
- embedder.py +187 -0
- filter.py +152 -0
- hf_README.md +41 -0
- predictor.py +216 -0
- requirements.txt +9 -0
- taxonomy.py +200 -0
app.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py β FunGO HuggingFace Space
|
| 2 |
+
"""
|
| 3 |
+
FunGO v2.0 β HuggingFace Spaces Deployment
|
| 4 |
+
=============================================
|
| 5 |
+
Flask API running on port 7860.
|
| 6 |
+
Model files loaded from /data/ (HF persistent storage).
|
| 7 |
+
|
| 8 |
+
To upload model files:
|
| 9 |
+
pip install huggingface_hub
|
| 10 |
+
huggingface-cli login
|
| 11 |
+
huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/models /data/models --repo-type=space
|
| 12 |
+
huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/labels /data/labels --repo-type=space
|
| 13 |
+
huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/go_data /data/go_data --repo-type=space
|
| 14 |
+
huggingface-cli upload Muteeba/FunGO ./pipeline_outputs/features /data/features --repo-type=space
|
| 15 |
+
huggingface-cli upload Muteeba/FunGO /mnt/e/repeat/embeddings/model_cache /data/esm2_cache --repo-type=space
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import csv
|
| 19 |
+
import io
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import re as _re
|
| 23 |
+
import sys
|
| 24 |
+
import time
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
|
| 27 |
+
# ββ HuggingFace paths βββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
os.environ.setdefault("FUNGO_PKL_DIR", "/data/models")
|
| 29 |
+
os.environ.setdefault("FUNGO_VOCAB_PKL", "/data/labels/vocabularies.pkl")
|
| 30 |
+
os.environ.setdefault("FUNGO_IA_PKL", "/data/go_data/ia_weights.pkl")
|
| 31 |
+
os.environ.setdefault("FUNGO_FEAT_META", "/data/features/feature_metadata.json")
|
| 32 |
+
os.environ.setdefault("FUNGO_MODEL_CACHE","/data/esm2_cache")
|
| 33 |
+
os.environ.setdefault("FUNGO_EMB_CACHE", "/data/embedding_cache")
|
| 34 |
+
os.environ.setdefault("FUNGO_OFFLINE", "1")
|
| 35 |
+
os.environ.setdefault("FUNGO_DEBUG", "0")
|
| 36 |
+
os.environ.setdefault("FUNGO_PORT", "7860")
|
| 37 |
+
|
| 38 |
+
from flask import Flask, jsonify, request, Response
|
| 39 |
+
from flask_cors import CORS
|
| 40 |
+
|
| 41 |
+
import config
|
| 42 |
+
import predictor
|
| 43 |
+
import embedder
|
| 44 |
+
import filter as flt
|
| 45 |
+
import taxonomy
|
| 46 |
+
|
| 47 |
+
logging.basicConfig(
|
| 48 |
+
level=logging.INFO,
|
| 49 |
+
format="%(asctime)s [%(levelname)s] %(name)s β %(message)s",
|
| 50 |
+
datefmt="%H:%M:%S",
|
| 51 |
+
)
|
| 52 |
+
log = logging.getLogger("fungo.app")
|
| 53 |
+
|
| 54 |
+
app = Flask(__name__)
|
| 55 |
+
CORS(app)
|
| 56 |
+
app.config["MAX_CONTENT_LENGTH"] = 2 * 1024 * 1024
|
| 57 |
+
|
| 58 |
+
_csv_store: OrderedDict = OrderedDict()
|
| 59 |
+
_CSV_MAX = 50
|
| 60 |
+
|
| 61 |
+
def _store_csv(job_id, predictions):
|
| 62 |
+
if len(_csv_store) >= _CSV_MAX:
|
| 63 |
+
_csv_store.popitem(last=False)
|
| 64 |
+
_csv_store[job_id] = {"predictions": predictions, "ts": time.time()}
|
| 65 |
+
|
| 66 |
+
def _make_csv(predictions):
|
| 67 |
+
out = io.StringIO()
|
| 68 |
+
w = csv.writer(out)
|
| 69 |
+
w.writerow(["protein_id","go_term","ontology","ontology_label",
|
| 70 |
+
"tier","tier_label","confidence","ia_weight","combined_score","threshold"])
|
| 71 |
+
for pid, data in predictions.items():
|
| 72 |
+
for p in data.get("all", []):
|
| 73 |
+
w.writerow([pid, p.get("go_term",""), p.get("ontology",""),
|
| 74 |
+
p.get("ontology_label",""), p.get("tier",""), p.get("tier_label",""),
|
| 75 |
+
p.get("confidence",""), p.get("ia_weight",""),
|
| 76 |
+
p.get("combined_score",""), p.get("threshold","")])
|
| 77 |
+
return out.getvalue()
|
| 78 |
+
|
| 79 |
+
_OX_RE = _re.compile(r"OX=(\d+)")
|
| 80 |
+
|
| 81 |
+
def _parse_taxon_id(header):
|
| 82 |
+
m = _OX_RE.search(header or "")
|
| 83 |
+
return int(m.group(1)) if m else None
|
| 84 |
+
|
| 85 |
+
def parse_fasta(fasta_text):
|
| 86 |
+
proteins, current_id, current_hdr, current_seq = [], None, None, []
|
| 87 |
+
for raw_line in fasta_text.splitlines():
|
| 88 |
+
line = raw_line.strip()
|
| 89 |
+
if not line: continue
|
| 90 |
+
if line.startswith(">"):
|
| 91 |
+
if current_id is not None:
|
| 92 |
+
seq = "".join(current_seq).upper()
|
| 93 |
+
if seq:
|
| 94 |
+
proteins.append({"id": current_id, "seq": seq,
|
| 95 |
+
"header": current_hdr, "taxon_id": _parse_taxon_id(current_hdr)})
|
| 96 |
+
current_hdr = line[1:].strip()
|
| 97 |
+
parts = current_hdr.split("|")
|
| 98 |
+
current_id = parts[1] if len(parts) >= 3 else current_hdr.split()[0]
|
| 99 |
+
current_seq = []
|
| 100 |
+
else:
|
| 101 |
+
current_seq.append(line)
|
| 102 |
+
if current_id is not None:
|
| 103 |
+
seq = "".join(current_seq).upper()
|
| 104 |
+
if seq:
|
| 105 |
+
proteins.append({"id": current_id, "seq": seq,
|
| 106 |
+
"header": current_hdr, "taxon_id": _parse_taxon_id(current_hdr)})
|
| 107 |
+
if not proteins:
|
| 108 |
+
raise ValueError("No valid protein sequences found in FASTA input.")
|
| 109 |
+
return proteins
|
| 110 |
+
|
| 111 |
+
def _run_prediction(fasta_text, taxon_id_override):
|
| 112 |
+
proteins = parse_fasta(fasta_text)
|
| 113 |
+
if len(proteins) > config.MAX_SEQUENCES:
|
| 114 |
+
raise ValueError(f"Too many sequences. Max: {config.MAX_SEQUENCES}.")
|
| 115 |
+
protein_ids = [p["id"] for p in proteins]
|
| 116 |
+
sequences = [p["seq"] for p in proteins]
|
| 117 |
+
taxon_ids = [taxon_id_override if taxon_id_override is not None
|
| 118 |
+
else p["taxon_id"] for p in proteins]
|
| 119 |
+
log.info("Proteins: %s | Taxon IDs: %s", protein_ids, taxon_ids)
|
| 120 |
+
t0 = time.perf_counter()
|
| 121 |
+
X_esm = embedder.extract(sequences)
|
| 122 |
+
top50 = predictor.get_top50_taxa()
|
| 123 |
+
X_final = embedder.build_features(X_esm, taxon_ids, top50)
|
| 124 |
+
raw_preds = predictor.predict(X_final, protein_ids)
|
| 125 |
+
ia_weights = predictor.get_ia_weights()
|
| 126 |
+
for p in raw_preds:
|
| 127 |
+
p["ia_weight"] = round(float(ia_weights.get(p["go_term"], 0.0)), 4)
|
| 128 |
+
return proteins, raw_preds, ia_weights, round(time.perf_counter() - t0, 2)
|
| 129 |
+
|
| 130 |
+
@app.route("/health", methods=["GET"])
|
| 131 |
+
def health():
|
| 132 |
+
return jsonify({"status":"ok","device":config.DEVICE,"fp16":config.USE_FP16,"version":"2.0.0"})
|
| 133 |
+
|
| 134 |
+
@app.route("/model/info", methods=["GET"])
|
| 135 |
+
def model_info():
|
| 136 |
+
try: stats = predictor.get_model_stats()
|
| 137 |
+
except RuntimeError as e: return jsonify({"error": str(e)}), 503
|
| 138 |
+
return jsonify({"device":config.DEVICE,"fp16":config.USE_FP16,
|
| 139 |
+
"model_name":config.MODEL_NAME,"ontologies":stats,
|
| 140 |
+
"top50_taxa_count":len(predictor.get_top50_taxa()),
|
| 141 |
+
"thresholds":{
|
| 142 |
+
"STRONG": {"min_ia":config.TIER_GOLD_IA, "min_conf":config.TIER_GOLD_CONF},
|
| 143 |
+
"MODERATE": {"min_ia":config.TIER_GOOD_IA, "min_conf":config.TIER_GOOD_CONF},
|
| 144 |
+
"INDICATIVE":{"min_ia":config.TIER_SILVER_IA, "min_conf":config.TIER_SILVER_CONF},
|
| 145 |
+
},"display_limit":flt.TOP_N_DISPLAY})
|
| 146 |
+
|
| 147 |
+
@app.route("/taxonomy/search", methods=["GET"])
|
| 148 |
+
def taxonomy_search():
|
| 149 |
+
q = request.args.get("q","").strip()
|
| 150 |
+
if len(q) < 2: return jsonify({"error":"Query must be at least 2 characters."}), 400
|
| 151 |
+
try: max_r = min(int(request.args.get("max_results",8)),20)
|
| 152 |
+
except: max_r = 8
|
| 153 |
+
return jsonify({"query":q,"results":taxonomy.search_species(q,max_results=max_r)})
|
| 154 |
+
|
| 155 |
+
@app.route("/taxonomy/verify", methods=["GET"])
|
| 156 |
+
def taxonomy_verify():
|
| 157 |
+
raw = request.args.get("taxon_id","")
|
| 158 |
+
if not raw: return jsonify({"error":"taxon_id required."}), 400
|
| 159 |
+
try: taxon_id = int(raw)
|
| 160 |
+
except: return jsonify({"error":f"Invalid taxon_id: '{raw}'"}), 400
|
| 161 |
+
return jsonify(taxonomy.resolve_taxon(taxon_id, predictor.get_top50_taxa()))
|
| 162 |
+
|
| 163 |
+
@app.route("/predict", methods=["POST"])
|
| 164 |
+
def predict():
|
| 165 |
+
if not request.is_json: return jsonify({"error":"Content-Type must be application/json."}), 415
|
| 166 |
+
body = request.get_json(silent=True) or {}
|
| 167 |
+
fasta_text = body.get("fasta","").strip()
|
| 168 |
+
if not fasta_text: return jsonify({"error":"'fasta' field is required."}), 400
|
| 169 |
+
taxon_id_override = None
|
| 170 |
+
if "taxon_id" in body:
|
| 171 |
+
try: taxon_id_override = int(body["taxon_id"])
|
| 172 |
+
except: return jsonify({"error":f"Invalid taxon_id"}), 400
|
| 173 |
+
try:
|
| 174 |
+
proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
|
| 175 |
+
except ValueError as e: return jsonify({"error": str(e)}), 400
|
| 176 |
+
except RuntimeError as e: return jsonify({"error": str(e)}), 503
|
| 177 |
+
except Exception as e:
|
| 178 |
+
log.exception("Prediction error"); return jsonify({"error": str(e)}), 500
|
| 179 |
+
|
| 180 |
+
protein_ids = [p["id"] for p in proteins]
|
| 181 |
+
raw_by_pid = {pid:[] for pid in protein_ids}
|
| 182 |
+
for pred in raw_preds: raw_by_pid[pred["protein_id"]].append(pred)
|
| 183 |
+
|
| 184 |
+
predictions, csv_data, total_display, total_all = {}, {}, 0, 0
|
| 185 |
+
for prot in proteins:
|
| 186 |
+
pid = prot["id"]
|
| 187 |
+
res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
|
| 188 |
+
display, all_f = res["display"], res["all"]
|
| 189 |
+
total_display += len(display); total_all += len(all_f)
|
| 190 |
+
predictions[pid] = {"taxon_id":prot["taxon_id"],
|
| 191 |
+
"summary":flt.summarise(display,all_f,pid),
|
| 192 |
+
"display":display,"total_all":len(all_f)}
|
| 193 |
+
csv_data[pid] = {"all":all_f}
|
| 194 |
+
|
| 195 |
+
job_id = str(int(time.time()*1000))
|
| 196 |
+
_store_csv(job_id, csv_data)
|
| 197 |
+
return jsonify({"job_id":job_id,
|
| 198 |
+
"metadata":{"n_proteins":len(protein_ids),"device":config.DEVICE,
|
| 199 |
+
"total_raw_predictions":len(raw_preds),"total_filtered":total_all,
|
| 200 |
+
"total_displayed":total_display,"display_limit":flt.TOP_N_DISPLAY,
|
| 201 |
+
"elapsed_seconds":elapsed},
|
| 202 |
+
"predictions":predictions})
|
| 203 |
+
|
| 204 |
+
@app.route("/predict/csv", methods=["GET"])
|
| 205 |
+
def download_csv():
|
| 206 |
+
job_id = request.args.get("job_id","").strip()
|
| 207 |
+
if not job_id: return jsonify({"error":"job_id required."}), 400
|
| 208 |
+
job = _csv_store.get(job_id)
|
| 209 |
+
if not job: return jsonify({"error":f"Job '{job_id}' not found."}), 404
|
| 210 |
+
return Response(_make_csv(job["predictions"]), mimetype="text/csv",
|
| 211 |
+
headers={"Content-Disposition":f"attachment; filename=fungo_{job_id}.csv"})
|
| 212 |
+
|
| 213 |
+
@app.route("/predict/debug", methods=["POST"])
|
| 214 |
+
def predict_debug():
|
| 215 |
+
if not request.is_json: return jsonify({"error":"Content-Type must be application/json."}), 415
|
| 216 |
+
body = request.get_json(silent=True) or {}
|
| 217 |
+
fasta_text = body.get("fasta","").strip()
|
| 218 |
+
if not fasta_text: return jsonify({"error":"'fasta' required."}), 400
|
| 219 |
+
taxon_id_override = None
|
| 220 |
+
if "taxon_id" in body:
|
| 221 |
+
try: taxon_id_override = int(body["taxon_id"])
|
| 222 |
+
except: return jsonify({"error":"Invalid taxon_id"}), 400
|
| 223 |
+
try:
|
| 224 |
+
proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
|
| 225 |
+
except ValueError as e: return jsonify({"error":str(e)}), 400
|
| 226 |
+
except RuntimeError as e: return jsonify({"error":str(e)}), 503
|
| 227 |
+
except Exception as e:
|
| 228 |
+
log.exception("Debug error"); return jsonify({"error":str(e)}), 500
|
| 229 |
+
|
| 230 |
+
protein_ids = [p["id"] for p in proteins]
|
| 231 |
+
raw_by_pid = {pid:[] for pid in protein_ids}
|
| 232 |
+
for pred in raw_preds: raw_by_pid[pred["protein_id"]].append(pred)
|
| 233 |
+
|
| 234 |
+
thr = {"STRONG":{"min_ia":config.TIER_GOLD_IA,"min_conf":config.TIER_GOLD_CONF},
|
| 235 |
+
"MODERATE":{"min_ia":config.TIER_GOOD_IA,"min_conf":config.TIER_GOOD_CONF},
|
| 236 |
+
"INDICATIVE":{"min_ia":config.TIER_SILVER_IA,"min_conf":config.TIER_SILVER_CONF}}
|
| 237 |
+
predictions = {}
|
| 238 |
+
for prot in proteins:
|
| 239 |
+
pid = prot["id"]
|
| 240 |
+
res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
|
| 241 |
+
display, all_f = res["display"], res["all"]
|
| 242 |
+
accepted = {p["go_term"] for p in all_f}
|
| 243 |
+
fo = []
|
| 244 |
+
for pred in raw_by_pid[pid]:
|
| 245 |
+
go = pred["go_term"]
|
| 246 |
+
if go in accepted: continue
|
| 247 |
+
ia, conf = pred.get("ia_weight", float(ia_weights.get(go,0.0))), pred["confidence"]
|
| 248 |
+
if go in config.BLACKLIST_TERMS: reason="blacklisted"
|
| 249 |
+
elif ia <= config.TIER_SILVER_IA: reason=f"ia_too_low (ia={ia:.4f})"
|
| 250 |
+
elif conf < config.TIER_SILVER_CONF: reason=f"conf_too_low (conf={conf:.4f})"
|
| 251 |
+
else: reason="below_all_tiers"
|
| 252 |
+
fo.append({"go_term":go,"ontology":pred["ontology"],"confidence":conf,
|
| 253 |
+
"ia_weight":ia,"reason":reason})
|
| 254 |
+
fo.sort(key=lambda x:-x["ia_weight"])
|
| 255 |
+
predictions[pid] = {"taxon_id":prot["taxon_id"],
|
| 256 |
+
"summary":flt.summarise(display,all_f,pid),
|
| 257 |
+
"display":display,"all_filtered":all_f,
|
| 258 |
+
"filtered_out":fo,"thresholds_used":thr}
|
| 259 |
+
return jsonify({"metadata":{"n_proteins":len(protein_ids),"device":config.DEVICE,
|
| 260 |
+
"total_raw":len(raw_preds),"elapsed_seconds":elapsed},"predictions":predictions})
|
| 261 |
+
|
| 262 |
+
@app.errorhandler(404)
|
| 263 |
+
def not_found(e): return jsonify({"error":"Not found."}), 404
|
| 264 |
+
@app.errorhandler(413)
|
| 265 |
+
def too_large(e): return jsonify({"error":"Request too large."}), 413
|
| 266 |
+
@app.errorhandler(500)
|
| 267 |
+
def internal(e):
|
| 268 |
+
log.exception("Unhandled error"); return jsonify({"error":"Internal server error."}), 500
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
log.info("FunGO v2.0 β HuggingFace Space starting β¦")
|
| 272 |
+
config.ensure_dirs()
|
| 273 |
+
if not config.validate_paths():
|
| 274 |
+
log.error("Model paths missing!")
|
| 275 |
+
sys.exit(1)
|
| 276 |
+
predictor.load_all()
|
| 277 |
+
log.info("Models loaded. Serving on port 7860 β¦")
|
| 278 |
+
app.run(host="0.0.0.0", port=7860, debug=False)
|
config.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
"""
|
| 3 |
+
FunGO Backend β Central Configuration
|
| 4 |
+
======================================
|
| 5 |
+
ONLY change paths in this file. Nothing else needs editing.
|
| 6 |
+
|
| 7 |
+
How to use:
|
| 8 |
+
- Update PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META to point to your model files
|
| 9 |
+
- Update MODEL_CACHE_DIR to point to your ESM2 weights cache
|
| 10 |
+
- All other settings work as-is
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("config")
|
| 19 |
+
|
| 20 |
+
# ββ DEVICE (auto-detected) ββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
+
USE_FP16 = DEVICE == "cuda"
|
| 23 |
+
|
| 24 |
+
# ββ MODEL PATHS β UPDATE THESE TO MATCH YOUR SYSTEM ββββββββββ
|
| 25 |
+
PKL_DIR = Path(os.environ.get("FUNGO_PKL_DIR", "/mnt/f/research/thesis/pipeline_outputs/models"))
|
| 26 |
+
VOCAB_PKL = Path(os.environ.get("FUNGO_VOCAB_PKL", "/mnt/f/research/thesis/pipeline_outputs/labels/vocabularies.pkl"))
|
| 27 |
+
IA_PKL = Path(os.environ.get("FUNGO_IA_PKL", "/mnt/f/research/thesis/pipeline_outputs/go_data/ia_weights.pkl"))
|
| 28 |
+
FEAT_META = Path(os.environ.get("FUNGO_FEAT_META", "/mnt/f/research/thesis/pipeline_outputs/features/feature_metadata.json"))
|
| 29 |
+
|
| 30 |
+
# ββ ESM2 SETTINGS βββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
MODEL_CACHE_DIR = Path(os.environ.get("FUNGO_MODEL_CACHE", "/mnt/e/repeat/embeddings/model_cache"))
|
| 32 |
+
MODEL_NAME = "facebook/esm2_t36_3B_UR50D"
|
| 33 |
+
LAYERS_TO_USE = [30, 31, 32, 33, 34, 35]
|
| 34 |
+
MAX_SEQ_LENGTH = 1400
|
| 35 |
+
BATCH_SIZE = 4 if DEVICE == "cpu" else 16
|
| 36 |
+
TRANSFORMERS_OFFLINE = os.environ.get("FUNGO_OFFLINE", "1")
|
| 37 |
+
|
| 38 |
+
# ββ EMBEDDING CACHE βββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
EMB_CACHE_DIR = Path(os.environ.get("FUNGO_EMB_CACHE", "./embedding_cache"))
|
| 40 |
+
|
| 41 |
+
# ββ FILTER THRESHOLDS (do not change) ββββββββββββββββββββββββ
|
| 42 |
+
BLACKLIST_TERMS = {
|
| 43 |
+
"GO:0003674","GO:0008150","GO:0005575","GO:0005488",
|
| 44 |
+
"GO:0043226","GO:0043229","GO:0043227","GO:0043231",
|
| 45 |
+
"GO:0110165","GO:0005622","GO:0005623","GO:0044464",
|
| 46 |
+
"GO:0043232","GO:0044424","GO:0009987","GO:0065007",
|
| 47 |
+
"GO:0050794","GO:0019222","GO:0060255","GO:0080090",
|
| 48 |
+
"GO:0050789",
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
# Strong Evidence (was GOLD)
|
| 52 |
+
TIER_GOLD_IA = 5.0
|
| 53 |
+
TIER_GOLD_CONF = 0.30
|
| 54 |
+
|
| 55 |
+
# Moderate Evidence (was GOOD)
|
| 56 |
+
TIER_GOOD_IA = 2.0
|
| 57 |
+
TIER_GOOD_CONF = 0.50
|
| 58 |
+
|
| 59 |
+
# Indicative (was SILVER)
|
| 60 |
+
TIER_SILVER_IA = 1.0
|
| 61 |
+
TIER_SILVER_CONF = 0.65
|
| 62 |
+
|
| 63 |
+
# ββ NCBI TAXONOMY API βββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
NCBI_SEARCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
|
| 65 |
+
NCBI_SUMMARY_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi"
|
| 66 |
+
NCBI_FETCH_URL = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
|
| 67 |
+
NCBI_TOOL = "FunGO"
|
| 68 |
+
NCBI_EMAIL = "fungo@research.com"
|
| 69 |
+
|
| 70 |
+
# ββ FLASK βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
PORT = int(os.environ.get("FUNGO_PORT", 5000))
|
| 72 |
+
DEBUG = os.environ.get("FUNGO_DEBUG", "0") == "1"
|
| 73 |
+
MAX_SEQUENCES = int(os.environ.get("FUNGO_MAX_SEQ", 10))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ββ Runtime helpers βββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
def ensure_dirs():
|
| 79 |
+
"""Create required runtime directories. Called once at startup."""
|
| 80 |
+
EMB_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
| 81 |
+
logger.info("[config] EMB_CACHE_DIR ready β %s", EMB_CACHE_DIR)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def validate_paths() -> bool:
|
| 85 |
+
"""
|
| 86 |
+
Check that all required model files exist.
|
| 87 |
+
Returns True if all found, False if any missing.
|
| 88 |
+
Called at startup before loading models.
|
| 89 |
+
"""
|
| 90 |
+
required = {
|
| 91 |
+
"PKL_DIR": PKL_DIR,
|
| 92 |
+
"VOCAB_PKL": VOCAB_PKL,
|
| 93 |
+
"IA_PKL": IA_PKL,
|
| 94 |
+
"FEAT_META": FEAT_META,
|
| 95 |
+
"MODEL_CACHE_DIR": MODEL_CACHE_DIR,
|
| 96 |
+
}
|
| 97 |
+
all_ok = True
|
| 98 |
+
for name, path in required.items():
|
| 99 |
+
if path.exists():
|
| 100 |
+
logger.info("[config] β %-18s β %s", name, path)
|
| 101 |
+
else:
|
| 102 |
+
logger.error("[config] β %-18s β %s (NOT FOUND)", name, path)
|
| 103 |
+
all_ok = False
|
| 104 |
+
return all_ok
|
embedder.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# embedder.py
|
| 2 |
+
"""
|
| 3 |
+
FunGO β ESM2 Embedding Extractor
|
| 4 |
+
==================================
|
| 5 |
+
Extracts layers 30β35 from ESM2-t36-3B.
|
| 6 |
+
- Auto-detects CPU vs GPU
|
| 7 |
+
- Caches embeddings per session to avoid re-extraction
|
| 8 |
+
- Lazy model loading (loaded only on first request)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import hashlib
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from config import (
|
| 17 |
+
MODEL_CACHE_DIR, MODEL_NAME, LAYERS_TO_USE,
|
| 18 |
+
MAX_SEQ_LENGTH, BATCH_SIZE, DEVICE, USE_FP16,
|
| 19 |
+
EMB_CACHE_DIR,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
| 23 |
+
os.environ["HF_DATASETS_OFFLINE"] = "1"
|
| 24 |
+
os.environ["TRANSFORMERS_CACHE"] = str(MODEL_CACHE_DIR)
|
| 25 |
+
os.environ["HF_HOME"] = str(MODEL_CACHE_DIR)
|
| 26 |
+
|
| 27 |
+
N_ESM_DIMS = len(LAYERS_TO_USE) * 2560 # 6 Γ 2560 = 15,360
|
| 28 |
+
|
| 29 |
+
# ββ Lazy globals ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
_tokenizer = None
|
| 31 |
+
_model = None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _load_model():
|
| 35 |
+
"""Load ESM2 tokenizer and model (only once)."""
|
| 36 |
+
global _tokenizer, _model
|
| 37 |
+
|
| 38 |
+
if _tokenizer is not None and _model is not None:
|
| 39 |
+
return _tokenizer, _model
|
| 40 |
+
|
| 41 |
+
print(f"[embedder] Loading ESM2 from local cache β {MODEL_CACHE_DIR}")
|
| 42 |
+
print(f"[embedder] Device: {DEVICE} | FP16: {USE_FP16}")
|
| 43 |
+
|
| 44 |
+
from transformers import EsmTokenizer, EsmModel
|
| 45 |
+
|
| 46 |
+
_tokenizer = EsmTokenizer.from_pretrained(
|
| 47 |
+
MODEL_NAME,
|
| 48 |
+
cache_dir=MODEL_CACHE_DIR,
|
| 49 |
+
local_files_only=True,
|
| 50 |
+
)
|
| 51 |
+
_model = EsmModel.from_pretrained(
|
| 52 |
+
MODEL_NAME,
|
| 53 |
+
cache_dir=MODEL_CACHE_DIR,
|
| 54 |
+
output_hidden_states=True,
|
| 55 |
+
local_files_only=True,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if USE_FP16:
|
| 59 |
+
_model = _model.to(DEVICE).half()
|
| 60 |
+
else:
|
| 61 |
+
_model = _model.to(DEVICE)
|
| 62 |
+
|
| 63 |
+
_model.eval()
|
| 64 |
+
for p in _model.parameters():
|
| 65 |
+
p.requires_grad = False
|
| 66 |
+
|
| 67 |
+
print(f"[embedder] Model ready on {DEVICE}")
|
| 68 |
+
return _tokenizer, _model
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _seq_cache_key(sequences: list) -> str:
|
| 72 |
+
"""Hash sequences to use as cache filename."""
|
| 73 |
+
joined = "|".join(f"{s[:50]}{len(s)}" for s in sequences)
|
| 74 |
+
return hashlib.md5(joined.encode()).hexdigest()[:16]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _load_cache(key: str):
|
| 78 |
+
path = EMB_CACHE_DIR / f"{key}.npy"
|
| 79 |
+
if path.exists():
|
| 80 |
+
return np.load(str(path))
|
| 81 |
+
return None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _save_cache(key: str, arr: np.ndarray):
|
| 85 |
+
np.save(str(EMB_CACHE_DIR / f"{key}.npy"), arr)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def extract(sequences: list) -> np.ndarray:
|
| 89 |
+
"""
|
| 90 |
+
Extract ESM2 embeddings for a list of sequences.
|
| 91 |
+
Returns np.ndarray of shape (N, 15360), dtype float32.
|
| 92 |
+
Sequences are truncated to MAX_SEQ_LENGTH if needed.
|
| 93 |
+
Uses cache to avoid re-extraction.
|
| 94 |
+
"""
|
| 95 |
+
# Truncate sequences
|
| 96 |
+
seqs_truncated = [s[:MAX_SEQ_LENGTH] for s in sequences]
|
| 97 |
+
N = len(seqs_truncated)
|
| 98 |
+
|
| 99 |
+
# Check cache
|
| 100 |
+
cache_key = _seq_cache_key(seqs_truncated)
|
| 101 |
+
cached_emb = _load_cache(cache_key)
|
| 102 |
+
if cached_emb is not None and cached_emb.shape == (N, N_ESM_DIMS):
|
| 103 |
+
print(f"[embedder] Cache hit β skipping extraction for {N} sequences")
|
| 104 |
+
return cached_emb.astype(np.float32)
|
| 105 |
+
|
| 106 |
+
print(f"[embedder] Extracting embeddings: {N} sequences on {DEVICE}")
|
| 107 |
+
|
| 108 |
+
tokenizer, model = _load_model()
|
| 109 |
+
|
| 110 |
+
X = np.zeros((N, N_ESM_DIMS), dtype=np.float32)
|
| 111 |
+
current_batch = BATCH_SIZE
|
| 112 |
+
|
| 113 |
+
with torch.no_grad():
|
| 114 |
+
i = 0
|
| 115 |
+
while i < N:
|
| 116 |
+
batch_end = min(i + current_batch, N)
|
| 117 |
+
batch_seqs = seqs_truncated[i:batch_end]
|
| 118 |
+
|
| 119 |
+
try:
|
| 120 |
+
inputs = tokenizer(
|
| 121 |
+
batch_seqs,
|
| 122 |
+
return_tensors="pt",
|
| 123 |
+
padding=True,
|
| 124 |
+
truncation=True,
|
| 125 |
+
max_length=MAX_SEQ_LENGTH + 2,
|
| 126 |
+
)
|
| 127 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
| 128 |
+
|
| 129 |
+
outputs = model(**inputs)
|
| 130 |
+
hidden_states = outputs.hidden_states
|
| 131 |
+
|
| 132 |
+
for j, seq in enumerate(batch_seqs):
|
| 133 |
+
seq_len = len(seq)
|
| 134 |
+
layer_vecs = []
|
| 135 |
+
|
| 136 |
+
for layer_idx in LAYERS_TO_USE:
|
| 137 |
+
h = hidden_states[layer_idx][j, 1:seq_len + 1, :]
|
| 138 |
+
v = h.mean(dim=0)
|
| 139 |
+
if DEVICE == "cuda":
|
| 140 |
+
v = v.float().cpu().numpy()
|
| 141 |
+
else:
|
| 142 |
+
v = v.numpy()
|
| 143 |
+
layer_vecs.append(v)
|
| 144 |
+
|
| 145 |
+
X[i + j] = np.concatenate(layer_vecs)
|
| 146 |
+
|
| 147 |
+
i += len(batch_seqs)
|
| 148 |
+
print(f"[embedder] {i}/{N} done")
|
| 149 |
+
|
| 150 |
+
except RuntimeError as e:
|
| 151 |
+
if "out of memory" in str(e).lower() and current_batch > 1:
|
| 152 |
+
current_batch = max(1, current_batch // 2)
|
| 153 |
+
print(f"[embedder] OOM β batch size reduced to {current_batch}")
|
| 154 |
+
if DEVICE == "cuda":
|
| 155 |
+
torch.cuda.empty_cache()
|
| 156 |
+
else:
|
| 157 |
+
raise
|
| 158 |
+
|
| 159 |
+
# Sanitise
|
| 160 |
+
bad = np.isnan(X).sum() + np.isinf(X).sum()
|
| 161 |
+
if bad > 0:
|
| 162 |
+
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
|
| 163 |
+
|
| 164 |
+
# Save cache
|
| 165 |
+
_save_cache(cache_key, X)
|
| 166 |
+
print(f"[embedder] Saved to cache: {cache_key}")
|
| 167 |
+
|
| 168 |
+
return X
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def build_features(X_esm: np.ndarray, taxon_ids: list,
|
| 172 |
+
top50_taxa: list) -> np.ndarray:
|
| 173 |
+
"""
|
| 174 |
+
Append 51-dim taxonomy features to ESM embeddings.
|
| 175 |
+
Returns (N, 15411) feature matrix.
|
| 176 |
+
"""
|
| 177 |
+
N = X_esm.shape[0]
|
| 178 |
+
taxon_to_i = {t: i for i, t in enumerate(top50_taxa)}
|
| 179 |
+
X_tax = np.zeros((N, 51), dtype=np.float32)
|
| 180 |
+
|
| 181 |
+
for i, tx in enumerate(taxon_ids):
|
| 182 |
+
if tx is not None and tx in taxon_to_i:
|
| 183 |
+
X_tax[i, taxon_to_i[tx]] = 1.0
|
| 184 |
+
else:
|
| 185 |
+
X_tax[i, 50] = 1.0 # unknown species flag
|
| 186 |
+
|
| 187 |
+
return np.hstack([X_esm, X_tax])
|
filter.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# filter.py
|
| 2 |
+
"""
|
| 3 |
+
FunGO β Smart Tier Filtering
|
| 4 |
+
==============================
|
| 5 |
+
Removes generic/root GO terms and assigns evidence tiers
|
| 6 |
+
to remaining predictions.
|
| 7 |
+
|
| 8 |
+
Changes from original:
|
| 9 |
+
1. Tier names updated:
|
| 10 |
+
GOLD β STRONG (Strong Evidence)
|
| 11 |
+
GOOD β MODERATE (Moderate Evidence)
|
| 12 |
+
SILVER β INDICATIVE
|
| 13 |
+
2. Combined score = ia_weight Γ confidence
|
| 14 |
+
Used for ranking β more scientifically sound.
|
| 15 |
+
3. filter_predictions() returns a dict with two keys:
|
| 16 |
+
"display" β top 20 by combined score (for UI screen)
|
| 17 |
+
"all" β full filtered list (for CSV download)
|
| 18 |
+
4. summarise() updated to use new tier keys.
|
| 19 |
+
5. Blacklist + IA/confidence thresholds β completely unchanged.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import logging
|
| 23 |
+
from config import (
|
| 24 |
+
BLACKLIST_TERMS,
|
| 25 |
+
TIER_GOLD_IA, TIER_GOLD_CONF,
|
| 26 |
+
TIER_GOOD_IA, TIER_GOOD_CONF,
|
| 27 |
+
TIER_SILVER_IA, TIER_SILVER_CONF,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
ONT_LABELS = {
|
| 33 |
+
"MFO": "Molecular Function",
|
| 34 |
+
"BPO": "Biological Process",
|
| 35 |
+
"CCO": "Cellular Component",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
TIER_LABELS = {
|
| 39 |
+
"STRONG": "Strong Evidence",
|
| 40 |
+
"MODERATE": "Moderate Evidence",
|
| 41 |
+
"INDICATIVE": "Indicative",
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
TIER_RANK = {"STRONG": 0, "MODERATE": 1, "INDICATIVE": 2}
|
| 45 |
+
|
| 46 |
+
# Max predictions shown on screen per protein
|
| 47 |
+
TOP_N_DISPLAY = 20
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def assign_tier(go_term: str, ia: float, confidence: float) -> str:
|
| 51 |
+
"""
|
| 52 |
+
Assign evidence tier. Thresholds unchanged from original.
|
| 53 |
+
|
| 54 |
+
Returns: "STRONG" | "MODERATE" | "INDICATIVE" | "NOISE"
|
| 55 |
+
"""
|
| 56 |
+
if go_term in BLACKLIST_TERMS:
|
| 57 |
+
return "NOISE"
|
| 58 |
+
if ia > TIER_GOLD_IA and confidence >= TIER_GOLD_CONF:
|
| 59 |
+
return "STRONG"
|
| 60 |
+
if ia > TIER_GOOD_IA and confidence >= TIER_GOOD_CONF:
|
| 61 |
+
return "MODERATE"
|
| 62 |
+
if ia > TIER_SILVER_IA and confidence >= TIER_SILVER_CONF:
|
| 63 |
+
return "INDICATIVE"
|
| 64 |
+
return "NOISE"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def combined_score(ia: float, confidence: float) -> float:
|
| 68 |
+
"""
|
| 69 |
+
Ranking score = ia_weight Γ confidence.
|
| 70 |
+
Balances specificity (IA) and model certainty (confidence).
|
| 71 |
+
"""
|
| 72 |
+
return round(ia * confidence, 6)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def filter_predictions(raw_predictions: list, ia_weights: dict) -> dict:
|
| 76 |
+
"""
|
| 77 |
+
Filter raw predictions and return display + full sets.
|
| 78 |
+
|
| 79 |
+
Returns
|
| 80 |
+
-------
|
| 81 |
+
{
|
| 82 |
+
"display": top-20 predictions (sorted by combined_score desc),
|
| 83 |
+
"all": all filtered predictions (for CSV)
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
Each prediction dict contains:
|
| 87 |
+
go_term, ontology, ontology_label, confidence, threshold,
|
| 88 |
+
ia_weight, combined_score, tier, tier_rank, tier_label
|
| 89 |
+
"""
|
| 90 |
+
filtered = []
|
| 91 |
+
|
| 92 |
+
for pred in raw_predictions:
|
| 93 |
+
go_term = pred["go_term"]
|
| 94 |
+
confidence = pred["confidence"]
|
| 95 |
+
ia = float(ia_weights.get(go_term, 0.0))
|
| 96 |
+
tier = assign_tier(go_term, ia, confidence)
|
| 97 |
+
|
| 98 |
+
if tier == "NOISE":
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
if tier not in TIER_RANK:
|
| 102 |
+
logger.warning("Unknown tier %r for %s β skipping", tier, go_term)
|
| 103 |
+
continue
|
| 104 |
+
|
| 105 |
+
score = combined_score(ia, confidence)
|
| 106 |
+
|
| 107 |
+
filtered.append({
|
| 108 |
+
**pred,
|
| 109 |
+
"ia_weight": round(ia, 4),
|
| 110 |
+
"combined_score": score,
|
| 111 |
+
"tier": tier,
|
| 112 |
+
"tier_rank": TIER_RANK[tier],
|
| 113 |
+
"tier_label": TIER_LABELS[tier],
|
| 114 |
+
"ontology_label": ONT_LABELS.get(pred["ontology"], pred["ontology"]),
|
| 115 |
+
})
|
| 116 |
+
|
| 117 |
+
# Sort by combined score descending, tier_rank as tiebreaker
|
| 118 |
+
filtered.sort(key=lambda x: (-x["combined_score"], x["tier_rank"]))
|
| 119 |
+
|
| 120 |
+
return {
|
| 121 |
+
"display": filtered[:TOP_N_DISPLAY],
|
| 122 |
+
"all": filtered,
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def summarise(filtered_display: list, all_filtered: list, protein_id: str) -> dict:
|
| 127 |
+
"""
|
| 128 |
+
Per-protein summary. Counts are over ALL filtered (not just top-20).
|
| 129 |
+
"""
|
| 130 |
+
ont_counts = {"MFO": 0, "BPO": 0, "CCO": 0}
|
| 131 |
+
tier_counts = {"STRONG": 0, "MODERATE": 0, "INDICATIVE": 0}
|
| 132 |
+
|
| 133 |
+
for p in all_filtered:
|
| 134 |
+
ont = p.get("ontology", "")
|
| 135 |
+
if ont in ont_counts:
|
| 136 |
+
ont_counts[ont] += 1
|
| 137 |
+
t = p.get("tier", "")
|
| 138 |
+
if t in tier_counts:
|
| 139 |
+
tier_counts[t] += 1
|
| 140 |
+
|
| 141 |
+
n = len(all_filtered)
|
| 142 |
+
return {
|
| 143 |
+
"protein_id": protein_id,
|
| 144 |
+
"total_filtered": n,
|
| 145 |
+
"displayed": len(filtered_display),
|
| 146 |
+
"by_ontology": ont_counts,
|
| 147 |
+
"by_tier": tier_counts,
|
| 148 |
+
"has_strong_evidence": tier_counts["STRONG"] > 0,
|
| 149 |
+
"avg_confidence": round(sum(p["confidence"] for p in all_filtered) / n, 4) if n else 0.0,
|
| 150 |
+
"avg_ia": round(sum(p["ia_weight"] for p in all_filtered) / n, 4) if n else 0.0,
|
| 151 |
+
"avg_combined_score": round(sum(p["combined_score"] for p in all_filtered) / n, 4) if n else 0.0,
|
| 152 |
+
}
|
hf_README.md
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: FunGO
|
| 3 |
+
emoji: π§¬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Protein Function Prediction using ESM2 + XGBoost
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# FunGO β Protein Function Prediction
|
| 15 |
+
|
| 16 |
+
**Beyond Prediction β Understanding Function.**
|
| 17 |
+
|
| 18 |
+
FunGO predicts Gene Ontology (GO) terms for protein sequences using:
|
| 19 |
+
- **ESM2-t36-3B** β protein language model embeddings (layers 30β35)
|
| 20 |
+
- **XGBoost classifiers** β 4,133 GO-term specific models
|
| 21 |
+
- **Evidence-tiered filtering** β Strong / Moderate / Indicative
|
| 22 |
+
|
| 23 |
+
## Evidence Tiers
|
| 24 |
+
|
| 25 |
+
| Tier | IA Weight | Confidence | Description |
|
| 26 |
+
|------|-----------|------------|-------------|
|
| 27 |
+
| Strong Evidence | > 5.0 | β₯ 0.30 | Highly specific GO term |
|
| 28 |
+
| Moderate Evidence | > 2.0 | β₯ 0.50 | Moderately specific term |
|
| 29 |
+
| Indicative | > 1.0 | β₯ 0.65 | Lower specificity, high confidence |
|
| 30 |
+
|
| 31 |
+
## Ontologies Covered
|
| 32 |
+
|
| 33 |
+
- **MFO** β Molecular Function
|
| 34 |
+
- **BPO** β Biological Process
|
| 35 |
+
- **CCO** β Cellular Component
|
| 36 |
+
|
| 37 |
+
## Development Team
|
| 38 |
+
|
| 39 |
+
- **Dr. Beenish Maqsood** β Principal Investigator, School of Biochemistry and Biotechnology, University of the Punjab
|
| 40 |
+
- **Dr. Naeem Mahmood** β Co-Supervisor, School of Biochemistry and Biotechnology, University of the Punjab
|
| 41 |
+
- **Muteeba Azhar** β Lead Developer, MS Researcher, University of the Punjab
|
predictor.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# predictor.py
|
| 2 |
+
"""
|
| 3 |
+
FunGO β Prediction Engine
|
| 4 |
+
===========================
|
| 5 |
+
Loads XGBoost models once at startup.
|
| 6 |
+
Runs inference across all 3 ontologies (MFO, BPO, CCO).
|
| 7 |
+
|
| 8 |
+
Changes from original:
|
| 9 |
+
1. Added get_model_stats() β returns classifier counts per ontology
|
| 10 |
+
(used by /model/info endpoint).
|
| 11 |
+
2. Fixed open() to use context managers (file handles now closed).
|
| 12 |
+
3. tempfile.mktemp() replaced with NamedTemporaryFile (WSL fix).
|
| 13 |
+
4. Failed classifiers are counted and logged instead of silent pass.
|
| 14 |
+
5. Input shape validation in predict().
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import pickle
|
| 20 |
+
import shutil
|
| 21 |
+
import subprocess
|
| 22 |
+
import tempfile
|
| 23 |
+
import numpy as np
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
from config import PKL_DIR, VOCAB_PKL, IA_PKL, FEAT_META
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
ONTS = ["MFO", "BPO", "CCO"]
|
| 30 |
+
|
| 31 |
+
# ββ Globals βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 32 |
+
_models_dict = None
|
| 33 |
+
_thresholds_dict = None
|
| 34 |
+
_ia_weights = None
|
| 35 |
+
_vocabularies = None
|
| 36 |
+
_top50_taxa = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 40 |
+
|
| 41 |
+
def _wsl_copy(src: Path) -> Path:
|
| 42 |
+
"""Copy file to temp path (WSL mounted-drive permission workaround)."""
|
| 43 |
+
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
| 44 |
+
tmp_path = Path(tmp.name)
|
| 45 |
+
shutil.copy2(str(src), str(tmp_path))
|
| 46 |
+
return tmp_path
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _safe_load(path: Path) -> object:
|
| 50 |
+
"""Load pickle with WSL permission workaround if needed."""
|
| 51 |
+
try:
|
| 52 |
+
subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
|
| 53 |
+
except Exception:
|
| 54 |
+
pass
|
| 55 |
+
try:
|
| 56 |
+
with open(path, "rb") as fh:
|
| 57 |
+
return pickle.load(fh)
|
| 58 |
+
except PermissionError:
|
| 59 |
+
pass
|
| 60 |
+
tmp_path = None
|
| 61 |
+
try:
|
| 62 |
+
tmp_path = _wsl_copy(path)
|
| 63 |
+
with open(tmp_path, "rb") as fh:
|
| 64 |
+
return pickle.load(fh)
|
| 65 |
+
finally:
|
| 66 |
+
if tmp_path and tmp_path.exists():
|
| 67 |
+
tmp_path.unlink()
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _safe_read_json(path: Path) -> dict:
|
| 71 |
+
"""Read JSON with WSL permission workaround."""
|
| 72 |
+
try:
|
| 73 |
+
subprocess.run(["chmod", "644", str(path)], check=False, capture_output=True)
|
| 74 |
+
except Exception:
|
| 75 |
+
pass
|
| 76 |
+
for mode in ("r", "rb"):
|
| 77 |
+
try:
|
| 78 |
+
with open(path, mode) as fh:
|
| 79 |
+
raw = fh.read()
|
| 80 |
+
if isinstance(raw, bytes):
|
| 81 |
+
raw = raw.decode("utf-8", errors="replace")
|
| 82 |
+
return json.loads(raw)
|
| 83 |
+
except PermissionError:
|
| 84 |
+
continue
|
| 85 |
+
result = subprocess.run(["cat", str(path)], capture_output=True, text=True, check=True)
|
| 86 |
+
return json.loads(result.stdout)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 90 |
+
|
| 91 |
+
def load_all():
|
| 92 |
+
"""
|
| 93 |
+
Load all models and supporting data into memory.
|
| 94 |
+
Call once at Flask startup (~30β120 s depending on hardware).
|
| 95 |
+
"""
|
| 96 |
+
global _models_dict, _thresholds_dict, _ia_weights, _vocabularies, _top50_taxa
|
| 97 |
+
|
| 98 |
+
logger.info("[predictor] Loading vocabularies β¦")
|
| 99 |
+
_vocabularies = _safe_load(VOCAB_PKL)
|
| 100 |
+
|
| 101 |
+
logger.info("[predictor] Loading IA weights β¦")
|
| 102 |
+
_ia_weights = _safe_load(IA_PKL)
|
| 103 |
+
logger.info("[predictor] IA weights: %d terms", len(_ia_weights))
|
| 104 |
+
|
| 105 |
+
meta = _safe_read_json(FEAT_META)
|
| 106 |
+
_top50_taxa = [int(t) for t in meta["taxonomy_info"]["top50_taxa"]]
|
| 107 |
+
logger.info("[predictor] Top-50 taxa loaded (%d)", len(_top50_taxa))
|
| 108 |
+
|
| 109 |
+
_models_dict = {}
|
| 110 |
+
_thresholds_dict = {}
|
| 111 |
+
|
| 112 |
+
for ont in ONTS:
|
| 113 |
+
pkl_path = PKL_DIR / f"models_{ont}.pkl"
|
| 114 |
+
size_mb = pkl_path.stat().st_size / 1e6
|
| 115 |
+
logger.info("[predictor] Loading %s (%.0f MB) β¦", pkl_path.name, size_mb)
|
| 116 |
+
|
| 117 |
+
raw = _safe_load(pkl_path)
|
| 118 |
+
first_key = next(iter(raw))
|
| 119 |
+
|
| 120 |
+
if first_key.startswith("GO:"):
|
| 121 |
+
models_d = raw
|
| 122 |
+
thresholds_d = {t: 0.5 for t in raw}
|
| 123 |
+
else:
|
| 124 |
+
clf_list = raw["models"]
|
| 125 |
+
term_list = raw["selected_terms"]
|
| 126 |
+
thr_raw = raw.get("thresholds", [0.5] * len(clf_list))
|
| 127 |
+
thr_list = (list(thr_raw) if not isinstance(thr_raw, dict)
|
| 128 |
+
else [thr_raw.get(t, 0.5) for t in term_list])
|
| 129 |
+
models_d = dict(zip(term_list, clf_list))
|
| 130 |
+
thresholds_d = dict(zip(term_list, thr_list))
|
| 131 |
+
|
| 132 |
+
_models_dict[ont] = models_d
|
| 133 |
+
_thresholds_dict[ont] = thresholds_d
|
| 134 |
+
logger.info("[predictor] %s: %d classifiers ready", ont, len(models_d))
|
| 135 |
+
|
| 136 |
+
logger.info("[predictor] All models loaded successfully.")
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_top50_taxa() -> list:
|
| 140 |
+
if _top50_taxa is None:
|
| 141 |
+
raise RuntimeError("Models not loaded β call load_all() first.")
|
| 142 |
+
return _top50_taxa
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def get_ia_weights() -> dict:
|
| 146 |
+
if _ia_weights is None:
|
| 147 |
+
raise RuntimeError("Models not loaded β call load_all() first.")
|
| 148 |
+
return _ia_weights
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def get_model_stats() -> dict:
|
| 152 |
+
"""
|
| 153 |
+
Return classifier counts per ontology.
|
| 154 |
+
Used by GET /model/info endpoint.
|
| 155 |
+
Returns: {"MFO": 1500, "BPO": 1500, "CCO": 1133}
|
| 156 |
+
"""
|
| 157 |
+
if _models_dict is None:
|
| 158 |
+
raise RuntimeError("Models not loaded β call load_all() first.")
|
| 159 |
+
return {ont: len(models) for ont, models in _models_dict.items()}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def predict(X_final: np.ndarray, protein_ids: list) -> list:
|
| 163 |
+
"""
|
| 164 |
+
Run inference for all proteins across all 3 ontologies.
|
| 165 |
+
|
| 166 |
+
Parameters
|
| 167 |
+
----------
|
| 168 |
+
X_final : (N, 15411) float32 feature matrix
|
| 169 |
+
protein_ids : list of N protein ID strings
|
| 170 |
+
|
| 171 |
+
Returns
|
| 172 |
+
-------
|
| 173 |
+
List of raw prediction dicts:
|
| 174 |
+
[{protein_id, go_term, ontology, confidence, threshold}, β¦]
|
| 175 |
+
"""
|
| 176 |
+
if _models_dict is None:
|
| 177 |
+
raise RuntimeError("Models not loaded β call load_all() first.")
|
| 178 |
+
|
| 179 |
+
N = X_final.shape[0]
|
| 180 |
+
if N != len(protein_ids):
|
| 181 |
+
raise ValueError(
|
| 182 |
+
f"X_final has {N} rows but protein_ids has {len(protein_ids)} entries."
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
all_preds = []
|
| 186 |
+
failed_terms = 0
|
| 187 |
+
|
| 188 |
+
for ont in ONTS:
|
| 189 |
+
ont_models = _models_dict[ont]
|
| 190 |
+
ont_thresholds = _thresholds_dict[ont]
|
| 191 |
+
n_terms = len(ont_models)
|
| 192 |
+
logger.info("[predictor] %s β scoring %d terms Γ %d proteins β¦", ont, n_terms, N)
|
| 193 |
+
|
| 194 |
+
for go_term, clf in ont_models.items():
|
| 195 |
+
threshold = float(ont_thresholds.get(go_term, 0.5))
|
| 196 |
+
try:
|
| 197 |
+
proba = clf.predict_proba(X_final)[:, 1]
|
| 198 |
+
for i, pid in enumerate(protein_ids):
|
| 199 |
+
conf = float(proba[i])
|
| 200 |
+
if conf >= threshold:
|
| 201 |
+
all_preds.append({
|
| 202 |
+
"protein_id": pid,
|
| 203 |
+
"go_term": go_term,
|
| 204 |
+
"ontology": ont,
|
| 205 |
+
"confidence": round(conf, 4),
|
| 206 |
+
"threshold": round(threshold, 4),
|
| 207 |
+
})
|
| 208 |
+
except Exception as exc:
|
| 209 |
+
failed_terms += 1
|
| 210 |
+
logger.warning("[predictor] Classifier failed %s/%s: %s", ont, go_term, exc)
|
| 211 |
+
|
| 212 |
+
if failed_terms:
|
| 213 |
+
logger.warning("[predictor] Total failed classifiers: %d", failed_terms)
|
| 214 |
+
|
| 215 |
+
logger.info("[predictor] Inference complete β %d raw predictions", len(all_preds))
|
| 216 |
+
return all_preds
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask>=3.0.0
|
| 2 |
+
flask-cors>=4.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
transformers>=4.35.0
|
| 6 |
+
xgboost>=2.0.0
|
| 7 |
+
requests>=2.31.0
|
| 8 |
+
gunicorn>=21.0.0
|
| 9 |
+
gradio>=4.0.0
|
taxonomy.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# taxonomy.py
|
| 2 |
+
"""
|
| 3 |
+
FunGO β NCBI Taxonomy Service
|
| 4 |
+
===============================
|
| 5 |
+
Species name β taxon ID lookup and reverse lookup.
|
| 6 |
+
|
| 7 |
+
Fixes applied:
|
| 8 |
+
1. UID string/int consistency β result_map keys are always strings,
|
| 9 |
+
now explicitly uses str(uid) so 9606 never resolves to {}.
|
| 10 |
+
2. Species-rank preference β results sorted so "species" rank
|
| 11 |
+
appears before "genus". Prevents 9605 (Homo genus) showing
|
| 12 |
+
before 9606 (Homo sapiens species).
|
| 13 |
+
3. Exact-name boost β exact query match moved to position 0.
|
| 14 |
+
4. Cache key includes max_results to prevent stale smaller lists.
|
| 15 |
+
5. xml.etree.ElementTree replaces fragile regex XML parsing.
|
| 16 |
+
6. Retry logic β 3 attempts with 2s gap on connection errors.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import time
|
| 21 |
+
import xml.etree.ElementTree as ET
|
| 22 |
+
import requests
|
| 23 |
+
|
| 24 |
+
from config import (
|
| 25 |
+
NCBI_SEARCH_URL, NCBI_SUMMARY_URL, NCBI_FETCH_URL,
|
| 26 |
+
NCBI_TOOL, NCBI_EMAIL,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
HEADERS = {"User-Agent": f"FunGO/1.0 ({NCBI_EMAIL})"}
|
| 31 |
+
TIMEOUT = 10
|
| 32 |
+
RETRIES = 3
|
| 33 |
+
RETRY_DELAY = 2
|
| 34 |
+
|
| 35 |
+
_RANK_PRIORITY = {
|
| 36 |
+
"species": 0, "subspecies": 1, "varietas": 2,
|
| 37 |
+
"forma": 3, "strain": 4, "no rank": 5,
|
| 38 |
+
"genus": 6, "family": 7, "order": 8,
|
| 39 |
+
"class": 9, "phylum": 10, "kingdom": 11, "superkingdom": 12,
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def _rank_priority(rank: str) -> int:
|
| 43 |
+
return _RANK_PRIORITY.get(rank.lower().strip(), 99)
|
| 44 |
+
|
| 45 |
+
_search_cache: dict = {}
|
| 46 |
+
_id_to_info_cache: dict = {}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _ncbi_get(url: str, params: dict) -> requests.Response:
|
| 50 |
+
last_exc = None
|
| 51 |
+
for attempt in range(1, RETRIES + 1):
|
| 52 |
+
try:
|
| 53 |
+
resp = requests.get(url, params=params, timeout=TIMEOUT, headers=HEADERS)
|
| 54 |
+
resp.raise_for_status()
|
| 55 |
+
return resp
|
| 56 |
+
except requests.RequestException as exc:
|
| 57 |
+
last_exc = exc
|
| 58 |
+
if attempt < RETRIES:
|
| 59 |
+
logger.warning("[taxonomy] Request error (attempt %d/%d): %s β retrying in %ds",
|
| 60 |
+
attempt, RETRIES, exc, RETRY_DELAY)
|
| 61 |
+
time.sleep(RETRY_DELAY)
|
| 62 |
+
raise last_exc
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def search_species(query: str, max_results: int = 8) -> list:
|
| 66 |
+
"""
|
| 67 |
+
Search NCBI taxonomy by species name.
|
| 68 |
+
Returns [{taxon_id, scientific_name, common_name, rank, division}]
|
| 69 |
+
Sorted: species rank first, exact name match at position 0.
|
| 70 |
+
"""
|
| 71 |
+
query = query.strip()
|
| 72 |
+
if len(query) < 2:
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
cache_key = (query.lower(), max_results)
|
| 76 |
+
if cache_key in _search_cache:
|
| 77 |
+
return _search_cache[cache_key]
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
search_resp = _ncbi_get(NCBI_SEARCH_URL, {
|
| 81 |
+
"db": "taxonomy", "term": query,
|
| 82 |
+
"retmax": max_results, "retmode": "json",
|
| 83 |
+
"tool": NCBI_TOOL, "email": NCBI_EMAIL,
|
| 84 |
+
})
|
| 85 |
+
ids = search_resp.json().get("esearchresult", {}).get("idlist", [])
|
| 86 |
+
|
| 87 |
+
if not ids:
|
| 88 |
+
_search_cache[cache_key] = []
|
| 89 |
+
return []
|
| 90 |
+
|
| 91 |
+
summary_resp = _ncbi_get(NCBI_SUMMARY_URL, {
|
| 92 |
+
"db": "taxonomy", "id": ",".join(ids),
|
| 93 |
+
"retmode": "json", "tool": NCBI_TOOL, "email": NCBI_EMAIL,
|
| 94 |
+
})
|
| 95 |
+
result_map = summary_resp.json().get("result", {})
|
| 96 |
+
uids = result_map.get("uids", ids)
|
| 97 |
+
|
| 98 |
+
results = []
|
| 99 |
+
for uid in uids:
|
| 100 |
+
item = result_map.get(str(uid), {}) # FIX: explicit str()
|
| 101 |
+
if not item:
|
| 102 |
+
continue
|
| 103 |
+
results.append({
|
| 104 |
+
"taxon_id": int(uid),
|
| 105 |
+
"scientific_name": item.get("scientificname", ""),
|
| 106 |
+
"common_name": item.get("commonname", ""),
|
| 107 |
+
"rank": item.get("rank", ""),
|
| 108 |
+
"division": item.get("division", ""),
|
| 109 |
+
})
|
| 110 |
+
|
| 111 |
+
# FIX: sort by rank β species before genus
|
| 112 |
+
results.sort(key=lambda r: _rank_priority(r.get("rank", "")))
|
| 113 |
+
|
| 114 |
+
# FIX: exact name match β front of list
|
| 115 |
+
q_lower = query.lower()
|
| 116 |
+
exact = [r for r in results if r["scientific_name"].lower() == q_lower]
|
| 117 |
+
rest = [r for r in results if r["scientific_name"].lower() != q_lower]
|
| 118 |
+
results = exact + rest
|
| 119 |
+
|
| 120 |
+
_search_cache[cache_key] = results
|
| 121 |
+
logger.info("[taxonomy] search %r β %d results", query, len(results))
|
| 122 |
+
return results
|
| 123 |
+
|
| 124 |
+
except Exception as exc:
|
| 125 |
+
logger.error("[taxonomy] search_species(%r) failed: %s", query, exc)
|
| 126 |
+
return [{"error": str(exc)}]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_taxon_info(taxon_id: int) -> dict:
|
| 130 |
+
"""
|
| 131 |
+
Reverse lookup: taxon ID β full species info with lineage.
|
| 132 |
+
Uses xml.etree.ElementTree β handles multi-line XML correctly.
|
| 133 |
+
"""
|
| 134 |
+
if taxon_id in _id_to_info_cache:
|
| 135 |
+
return _id_to_info_cache[taxon_id]
|
| 136 |
+
|
| 137 |
+
base = {
|
| 138 |
+
"taxon_id": taxon_id, "scientific_name": "",
|
| 139 |
+
"common_name": "", "rank": "", "division": "",
|
| 140 |
+
"lineage": "", "verified": False,
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
try:
|
| 144 |
+
resp = _ncbi_get(NCBI_FETCH_URL, {
|
| 145 |
+
"db": "taxonomy", "id": taxon_id,
|
| 146 |
+
"retmode": "xml", "tool": NCBI_TOOL, "email": NCBI_EMAIL,
|
| 147 |
+
})
|
| 148 |
+
|
| 149 |
+
root = ET.fromstring(resp.text)
|
| 150 |
+
taxon_el = root.find("Taxon")
|
| 151 |
+
|
| 152 |
+
if taxon_el is None:
|
| 153 |
+
base["error"] = "Taxon element not found in NCBI XML"
|
| 154 |
+
return base
|
| 155 |
+
|
| 156 |
+
def txt(tag: str) -> str:
|
| 157 |
+
el = taxon_el.find(tag)
|
| 158 |
+
return (el.text or "").strip() if el is not None else ""
|
| 159 |
+
|
| 160 |
+
lineage_parts = [
|
| 161 |
+
(a.findtext("ScientificName") or "").strip()
|
| 162 |
+
for a in taxon_el.findall("./LineageEx/Taxon")
|
| 163 |
+
]
|
| 164 |
+
|
| 165 |
+
common = (taxon_el.findtext("OtherNames/CommonName") or
|
| 166 |
+
taxon_el.findtext("CommonName") or "")
|
| 167 |
+
|
| 168 |
+
info = {
|
| 169 |
+
**base,
|
| 170 |
+
"scientific_name": txt("ScientificName"),
|
| 171 |
+
"common_name": common.strip(),
|
| 172 |
+
"rank": txt("Rank"),
|
| 173 |
+
"division": txt("Division"),
|
| 174 |
+
"lineage": " > ".join(p for p in lineage_parts if p),
|
| 175 |
+
"verified": True,
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
_id_to_info_cache[taxon_id] = info
|
| 179 |
+
logger.info("[taxonomy] Resolved taxon %d β %s", taxon_id, info["scientific_name"])
|
| 180 |
+
return info
|
| 181 |
+
|
| 182 |
+
except ET.ParseError as exc:
|
| 183 |
+
logger.error("[taxonomy] XML parse error for taxon %d: %s", taxon_id, exc)
|
| 184 |
+
base["error"] = f"XML parse error: {exc}"
|
| 185 |
+
return base
|
| 186 |
+
except Exception as exc:
|
| 187 |
+
logger.error("[taxonomy] get_taxon_info(%d) failed: %s", taxon_id, exc)
|
| 188 |
+
base["error"] = str(exc)
|
| 189 |
+
return base
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def resolve_taxon(taxon_id: int, top50_taxa: list) -> dict:
|
| 193 |
+
"""Check training-set membership for a taxon ID."""
|
| 194 |
+
info = get_taxon_info(taxon_id)
|
| 195 |
+
in_training = taxon_id in top50_taxa
|
| 196 |
+
return {
|
| 197 |
+
**info,
|
| 198 |
+
"in_training": in_training,
|
| 199 |
+
"training_status": "in_training_data" if in_training else "unknown_species_fallback",
|
| 200 |
+
}
|