# app.py — FunGO HuggingFace Space v2
"""Flask HTTP API for FunGO: health, model info, taxonomy lookup and
GO-term prediction endpoints, plus CSV export of recent prediction jobs."""

import csv
import io
import logging
import os
import re as _re
import sys
import time
from collections import OrderedDict

# Environment defaults are applied BEFORE the project modules are imported.
# NOTE(review): presumably config/predictor/embedder read these variables at
# import time — confirm before reordering anything in this section.
os.environ.setdefault("FUNGO_PKL_DIR", "/tmp/models")
os.environ.setdefault("FUNGO_VOCAB_PKL", "/app/data/labels/vocabularies.pkl")
os.environ.setdefault("FUNGO_IA_PKL", "/app/data/go_data/ia_weights.pkl")
os.environ.setdefault("FUNGO_FEAT_META", "/app/data/features/feature_metadata.json")
os.environ.setdefault("FUNGO_MODEL_CACHE", "/tmp/esm2_cache")
os.environ.setdefault("FUNGO_EMB_CACHE", "/tmp/embedding_cache")
os.environ.setdefault("FUNGO_OFFLINE", "0")
os.environ["HF_HUB_OFFLINE"] = "0"  # unconditionally overwritten, unlike the defaults above
os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", os.environ.get("HF_TOKEN", ""))
os.environ.setdefault("FUNGO_DEBUG", "0")
os.environ.setdefault("FUNGO_PORT", "7860")

from flask import Flask, jsonify, request, Response
from flask_cors import CORS

import config
import predictor
import embedder
import filter as flt
import taxonomy

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("fungo.app")

app = Flask(__name__)
CORS(app)
app.config["MAX_CONTENT_LENGTH"] = 2 * 1024 * 1024  # cap request bodies at 2 MiB

# In-memory store of recent prediction jobs for /predict/csv, oldest first.
_csv_store: OrderedDict = OrderedDict()
_CSV_MAX = 50
# Read by the routes below; nothing in this part of the file sets it to True.
_models_ready = False

# Per-prediction columns exported by _make_csv, in output order
# (each row is prefixed with the protein id).
_CSV_FIELDS = (
    "go_term",
    "ontology",
    "ontology_label",
    "tier",
    "tier_label",
    "confidence",
    "ia_weight",
    "combined_score",
    "threshold",
)


# ── Download models from HF Model Repo ───────────────────────
def download_models_if_needed():
    """Download pkl files from Muteeba/FunGO-models if not present.

    Returns:
        True if the three model files were already on disk or every download
        call completed; False if anything in the download path raised (the
        error is logged, not propagated).
    """
    models_dir = "/tmp/models"
    os.makedirs(models_dir, exist_ok=True)
    files_needed = ["models_BPO.pkl", "models_MFO.pkl", "models_CCO.pkl"]

    all_present = all(
        os.path.exists(os.path.join(models_dir, f)) for f in files_needed
    )
    if all_present:
        # Log the directory actually checked (the old message named a
        # different path than the one used above).
        log.info("Model files already present in %s — skipping download", models_dir)
        return True

    log.info("Downloading model files from Muteeba/FunGO-models ...")
    try:
        # Imported lazily so a missing huggingface_hub is reported through the
        # same "download failed" path instead of breaking module import.
        from huggingface_hub import hf_hub_download

        for fname in files_needed:
            dest = os.path.join(models_dir, fname)
            if os.path.exists(dest):
                log.info(" %s already exists — skip", fname)
                continue
            log.info(" Downloading %s ...", fname)
            hf_hub_download(
                token=os.environ.get("HF_TOKEN"),
                repo_id="Muteeba/FunGO-models",
                filename=fname,
                repo_type="model",
                local_dir=models_dir,
            )
            log.info(" %s done!", fname)
        log.info("All model files downloaded!")
        return True
    except Exception as e:
        log.error("Model download failed: %s", e)
        return False


# ── CSV helpers ───────────────────────────────────────────────
def _store_csv(job_id, predictions):
    """Remember *predictions* under *job_id*, keeping at most _CSV_MAX jobs.

    The oldest job is evicted when the store is full. Re-storing an existing
    job id replaces it and moves it to the newest position without evicting
    an unrelated job.
    """
    _csv_store.pop(job_id, None)
    while len(_csv_store) >= _CSV_MAX:
        _csv_store.popitem(last=False)
    _csv_store[job_id] = {"predictions": predictions, "ts": time.time()}


def _make_csv(predictions):
    """Render stored predictions as CSV text.

    Args:
        predictions: mapping of protein id -> dict whose optional "all" entry
            is a list of prediction dicts.

    Returns:
        CSV text with a header row followed by one row per prediction;
        keys missing from a prediction are written as empty cells.
    """
    out = io.StringIO()
    writer = csv.writer(out)
    writer.writerow(["protein_id", *_CSV_FIELDS])
    for pid, data in predictions.items():
        writer.writerows(
            [pid, *(p.get(field, "") for field in _CSV_FIELDS)]
            for p in data.get("all", [])
        )
    return out.getvalue()


_OX_RE = _re.compile(r"OX=(\d+)")


def _parse_taxon_id(header):
    """Return the integer following ``OX=`` in *header*, or None if absent."""
    m = _OX_RE.search(header or "")
    return int(m.group(1)) if m else None


def parse_fasta(fasta_text):
    """Parse FASTA text into a list of protein records.

    Each record is a dict with keys "id", "seq" (upper-cased), "header"
    (the header line without the leading ">") and "taxon_id" (from an
    ``OX=<digits>`` tag in the header, else None). For headers with at least
    three ``|``-separated parts the id is the second part; otherwise it is
    the first whitespace-delimited token. Blank lines are skipped, records
    with an empty sequence are dropped, and sequence lines that appear before
    the first header are ignored.

    Raises:
        ValueError: if a header line is empty, or if no record with a
            non-empty sequence is found.
    """
    proteins = []
    current_id, current_hdr, current_seq = None, None, []

    def _flush():
        # Emit the record collected so far, if it has any sequence.
        if current_id is None:
            return
        seq = "".join(current_seq).upper()
        if seq:
            proteins.append({
                "id": current_id,
                "seq": seq,
                "header": current_hdr,
                "taxon_id": _parse_taxon_id(current_hdr),
            })

    for raw_line in fasta_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            _flush()
            current_hdr = line[1:].strip()
            if not current_hdr:
                # Previously an IndexError escaped from split()[0]; report it
                # as bad input so callers can answer with a 400.
                raise ValueError("FASTA header line is empty (expected '>identifier').")
            parts = current_hdr.split("|")
            current_id = parts[1] if len(parts) >= 3 else current_hdr.split()[0]
            current_seq = []
        else:
            current_seq.append(line)
    _flush()

    if not proteins:
        raise ValueError("No valid protein sequences found.")
    return proteins


def _run_prediction(fasta_text, taxon_id_override):
    """Parse *fasta_text*, build features and run the predictor.

    Args:
        fasta_text: FASTA-formatted text.
        taxon_id_override: taxon id applied to every protein, or None to use
            the id parsed from each protein's header.

    Returns:
        Tuple ``(proteins, raw_preds, ia_weights, elapsed_seconds)``; every
        entry of raw_preds gets an "ia_weight" key (rounded to 4 decimals,
        0.0 when the GO term has no weight). The elapsed time covers
        embedding through prediction, not FASTA parsing.

    Raises:
        ValueError: on unparsable FASTA or more than config.MAX_SEQUENCES
            sequences.
    """
    proteins = parse_fasta(fasta_text)
    if len(proteins) > config.MAX_SEQUENCES:
        raise ValueError(f"Too many sequences. Max: {config.MAX_SEQUENCES}.")

    protein_ids = [p["id"] for p in proteins]
    sequences = [p["seq"] for p in proteins]
    taxon_ids = [
        taxon_id_override if taxon_id_override is not None else p["taxon_id"]
        for p in proteins
    ]

    t0 = time.perf_counter()
    X_esm = embedder.extract(sequences)
    top50 = predictor.get_top50_taxa()
    X_final = embedder.build_features(X_esm, taxon_ids, top50)
    raw_preds = predictor.predict(X_final, protein_ids)
    ia_weights = predictor.get_ia_weights()
    for p in raw_preds:
        p["ia_weight"] = round(float(ia_weights.get(p["go_term"], 0.0)), 4)
    return proteins, raw_preds, ia_weights, round(time.perf_counter() - t0, 2)


# ── Routes ────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Return the API name, version, readiness flag and endpoint list."""
    return jsonify({
        "name": "FunGO API",
        "version": "2.0.0",
        "status": "running",
        "models_ready": _models_ready,
        "endpoints": [
            "/health", "/model/info", "/taxonomy/search",
            "/taxonomy/verify", "/predict", "/predict/csv", "/predict/debug",
        ],
    })


@app.route("/health", methods=["GET"])
def health():
    """Return a status payload with device, fp16 flag and readiness."""
    return jsonify({
        "status": "ok",
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "version": "2.0.0",
        "models_ready": _models_ready,
    })


@app.route("/model/info", methods=["GET"])
def model_info():
    """Return model statistics and tier thresholds; 503 until models load."""
    if not _models_ready:
        return jsonify({"error": "Models not loaded yet."}), 503
    try:
        stats = predictor.get_model_stats()
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    return jsonify({
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "model_name": config.MODEL_NAME,
        "ontologies": stats,
        "top50_taxa_count": len(predictor.get_top50_taxa()),
        "thresholds": {
            "STRONG": {"min_ia": config.TIER_GOLD_IA, "min_conf": config.TIER_GOLD_CONF},
            "MODERATE": {"min_ia": config.TIER_GOOD_IA, "min_conf": config.TIER_GOOD_CONF},
            "INDICATIVE": {"min_ia": config.TIER_SILVER_IA, "min_conf": config.TIER_SILVER_CONF},
        },
        "display_limit": flt.TOP_N_DISPLAY,
    })


@app.route("/taxonomy/search", methods=["GET"])
def taxonomy_search():
    """Search species by the ``q`` query parameter (minimum 2 characters).

    ``max_results`` defaults to 8, is capped at 20 and floored at 1; a
    non-numeric value falls back to 8.
    """
    q = request.args.get("q", "").strip()
    if len(q) < 2:
        return jsonify({"error": "Query must be at least 2 characters."}), 400
    try:
        max_r = min(int(request.args.get("max_results", 8)), 20)
    except (TypeError, ValueError):
        max_r = 8
    max_r = max(max_r, 1)  # zero/negative limits are meaningless for a search
    return jsonify({"query": q, "results": taxonomy.search_species(q, max_results=max_r)})


@app.route("/taxonomy/verify", methods=["GET"])
def taxonomy_verify():
    """Resolve the integer ``taxon_id`` query parameter via the taxonomy module."""
    raw = request.args.get("taxon_id", "")
    if not raw:
        return jsonify({"error": "taxon_id required."}), 400
    try:
        taxon_id = int(raw)
    except ValueError:
        return jsonify({"error": f"Invalid taxon_id: '{raw}'"}), 400
    return jsonify(taxonomy.resolve_taxon(taxon_id, predictor.get_top50_taxa()))


@app.route("/predict", methods=["POST"])
def predict():
    """Run prediction on the posted FASTA and return filtered results.

    Expects a JSON object with a "fasta" string and an optional "taxon_id"
    that overrides the taxon of every sequence. The full filtered result set
    is stored under the returned "job_id" for later download via
    /predict/csv.

    Status codes: 503 models not ready / RuntimeError, 415 non-JSON request,
    400 invalid input / ValueError, 500 any other failure.
    """
    if not _models_ready:
        return jsonify({"error": "Models not loaded yet."}), 503
    if not request.is_json:
        return jsonify({"error": "Content-Type must be application/json."}), 415

    body = request.get_json(silent=True) or {}
    # A JSON array/string/number body has no .get(); answer 400, not a 500.
    if not isinstance(body, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    fasta_text = body.get("fasta", "")
    # Non-string values (null, numbers, lists) count as a missing field.
    if not isinstance(fasta_text, str) or not fasta_text.strip():
        return jsonify({"error": "'fasta' field is required."}), 400
    fasta_text = fasta_text.strip()

    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(
            fasta_text, taxon_id_override
        )
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    except Exception as e:
        log.exception("Prediction error")
        return jsonify({"error": str(e)}), 500

    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    predictions, csv_data = {}, {}
    total_display = total_all = 0
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        total_display += len(display)
        total_all += len(all_f)
        predictions[pid] = {
            "taxon_id": prot["taxon_id"],
            "summary": flt.summarise(display, all_f, pid),
            "display": display,
            "total_all": len(all_f),
        }
        csv_data[pid] = {"all": all_f}

    job_id = str(int(time.time() * 1000))
    _store_csv(job_id, csv_data)
    return jsonify({
        "job_id": job_id,
        "metadata": {
            "n_proteins": len(protein_ids),
            "device": config.DEVICE,
            "total_raw_predictions": len(raw_preds),
            "total_filtered": total_all,
            "total_displayed": total_display,
            "display_limit": flt.TOP_N_DISPLAY,
            "elapsed_seconds": elapsed,
        },
        "predictions": predictions,
    })


@app.route("/predict/csv", methods=["GET"])
def download_csv():
    """Return the stored predictions for ``job_id`` as a CSV attachment."""
    job_id = request.args.get("job_id", "").strip()
    if not job_id:
        return jsonify({"error": "job_id required."}), 400
    job = _csv_store.get(job_id)
    if not job:
        return jsonify({"error": f"Job '{job_id}' not found."}), 404
    return Response(
        _make_csv(job["predictions"]),
        mimetype="text/csv",
        headers={"Content-Disposition": f"attachment; filename=fungo_{job_id}.csv"},
    )


@app.route("/predict/debug", methods=["POST"])
def predict_debug():
    # NOTE(review): this function runs past the end of the visible source.
    # Its visible statements are reproduced below with formatting changes
    # only; no logic was altered and nothing beyond the cut was guessed.
    if not _models_ready:
        return jsonify({"error":"Models not loaded."}), 503
    if not request.is_json:
        return jsonify({"error":"Content-Type must be application/json."}), 415
    body = request.get_json(silent=True) or {}
    fasta_text = body.get("fasta","").strip()
    if not fasta_text:
        return jsonify({"error":"'fasta' required."}), 400
    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except:
            return jsonify({"error":"Invalid taxon_id"}), 400
    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
    except Exception as e:
        log.exception("Debug error")
        return jsonify({"error":str(e)}), 500
    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid:[] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)
    thr = {"STRONG":{"min_ia":config.TIER_GOLD_IA,"min_conf":config.TIER_GOLD_CONF},
           "MODERATE":{"min_ia":config.TIER_GOOD_IA,"min_conf":config.TIER_GOOD_CONF},
           "INDICATIVE":{"min_ia":config.TIER_SILVER_IA,"min_conf":config.TIER_SILVER_CONF}}
    predictions = {}
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        accepted = {p["go_term"] for p in all_f}
        fo = []
        for pred in raw_by_pid[pid]:
            go = pred["go_term"]
            if go in accepted:
                continue
            ia,conf = pred.get("ia_weight",float(ia_weights.get(go,0.0))),pred["confidence"]
            if go in config.BLACKLIST_TERMS:
                reason="blacklisted"
            elif ia<=config.TIER_SILVER_IA:
                reason=f"ia_too_low (ia={ia:.4f})"
            # NOTE(review): the visible source is cut off here, mid-statement,
            # at the tokens `elif conf`. The remainder of this branch chain
            # and of the function must be re-attached from the rest of the file.