# FunGO / app.py
# Muteeba's picture
# fix: explicitly disable HF_HUB_OFFLINE for model download
# 1162af1
# app.py — FunGO HuggingFace Space v2
import csv, io, logging, os, re as _re, sys, time
from collections import OrderedDict
# Default runtime configuration, expressed as environment variables.
# These lines run BEFORE the project modules (config, predictor, embedder, ...)
# are imported below — presumably so those modules see the values at import
# time (TODO confirm against config.py).
# setdefault() only fills in a value when the variable is not already set,
# so anything provided by the deployment environment wins.
os.environ.setdefault("FUNGO_PKL_DIR", "/tmp/models")
os.environ.setdefault("FUNGO_VOCAB_PKL", "/app/data/labels/vocabularies.pkl")
os.environ.setdefault("FUNGO_IA_PKL", "/app/data/go_data/ia_weights.pkl")
os.environ.setdefault("FUNGO_FEAT_META", "/app/data/features/feature_metadata.json")
os.environ.setdefault("FUNGO_MODEL_CACHE","/tmp/esm2_cache")
os.environ.setdefault("FUNGO_EMB_CACHE", "/tmp/embedding_cache")
os.environ.setdefault("FUNGO_OFFLINE", "0")
# Plain assignment (not setdefault): this deliberately overwrites any inherited
# HF_HUB_OFFLINE value so the model download is never blocked by offline mode.
os.environ["HF_HUB_OFFLINE"] = "0"
# Mirror HF_TOKEN into HUGGING_FACE_HUB_TOKEN unless the latter is already set;
# falls back to an empty string when HF_TOKEN is absent.
os.environ.setdefault("HUGGING_FACE_HUB_TOKEN", os.environ.get("HF_TOKEN", ""))
os.environ.setdefault("FUNGO_DEBUG", "0")
os.environ.setdefault("FUNGO_PORT", "7860")
from flask import Flask, jsonify, request, Response
from flask_cors import CORS
import config
import predictor
import embedder
import filter as flt
import taxonomy
# Root logging configuration: INFO level, HH:MM:SS timestamps.
# The separator between logger name and message is an em dash; the previous
# text was the same character mis-decoded ("β€”").
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    datefmt="%H:%M:%S",
)
# Module logger used throughout this file.
log = logging.getLogger("fungo.app")
# Flask application object; all routes below are registered on it.
app = Flask(__name__)
# Enable CORS on the whole app with flask_cors' default settings.
CORS(app)
# Cap request bodies at 2 MiB; larger requests are handled by the 413 handler.
app.config["MAX_CONTENT_LENGTH"] = 2 * 1024 * 1024
# In-memory, insertion-ordered store of per-job prediction data, keyed by job_id.
# Written by _store_csv() and read by the /predict/csv route.
_csv_store: OrderedDict = OrderedDict()
# Maximum number of jobs kept in _csv_store before the oldest is evicted.
_CSV_MAX = 50
# Flipped to True by the startup code once predictor.load_all() succeeds.
_models_ready = False
# ── Download models from HF Model Repo ───────────────────────
def download_models_if_needed():
    """Make sure the three ontology model pickles exist locally.

    Looks in the model directory (``FUNGO_PKL_DIR``, default ``/tmp/models``)
    for ``models_BPO.pkl``, ``models_MFO.pkl`` and ``models_CCO.pkl`` and
    downloads whichever are missing from the ``Muteeba/FunGO-models`` model
    repository on the Hugging Face Hub.

    Returns:
        bool: True if every file was already present or was downloaded,
        False if anything went wrong (the error is logged, never raised).
    """
    # Honour the same directory the rest of the app is configured with instead
    # of a second hard-coded copy of the path.
    models_dir = os.environ.get("FUNGO_PKL_DIR") or "/tmp/models"
    files_needed = ("models_BPO.pkl", "models_MFO.pkl", "models_CCO.pkl")
    missing = [
        fname for fname in files_needed
        if not os.path.exists(os.path.join(models_dir, fname))
    ]
    if not missing:
        log.info("Model files already present in %s — skipping download", models_dir)
        return True

    log.info("Downloading model files from Muteeba/FunGO-models ...")
    try:
        # Imported lazily so the dependency is only needed when a download happens.
        from huggingface_hub import hf_hub_download

        os.makedirs(models_dir, exist_ok=True)
        # An empty HF_TOKEN must be passed as None (anonymous), not as "".
        token = os.environ.get("HF_TOKEN") or None
        for fname in missing:
            log.info(" Downloading %s ...", fname)
            hf_hub_download(
                token=token,
                repo_id="Muteeba/FunGO-models",
                filename=fname,
                repo_type="model",
                local_dir=models_dir,
            )
            log.info(" %s done!", fname)
        log.info("All model files downloaded!")
        return True
    except Exception as e:
        # Best effort: startup continues and predictions stay disabled.
        log.error("Model download failed: %s", e)
        return False
# ── CSV helpers ───────────────────────────────────────────────
def _store_csv(job_id, predictions):
    """Remember *predictions* under *job_id* for a later CSV download.

    When the store already holds ``_CSV_MAX`` entries, the oldest one
    (first inserted) is dropped before the new entry is written.
    """
    store_is_full = len(_csv_store) >= _CSV_MAX
    if store_is_full:
        # last=False pops from the front, i.e. the oldest insertion.
        _csv_store.popitem(last=False)
    entry = {"predictions": predictions, "ts": time.time()}
    _csv_store[job_id] = entry
def _make_csv(predictions):
    """Render stored predictions as CSV text.

    Args:
        predictions: mapping of protein id -> dict whose optional ``"all"``
            entry is a list of prediction dicts.

    Returns:
        str: a header row followed by one row per prediction; keys missing
        from a prediction dict are written as empty cells.
    """
    columns = [
        "protein_id", "go_term", "ontology", "ontology_label",
        "tier", "tier_label", "confidence", "ia_weight",
        "combined_score", "threshold",
    ]
    # Every column after protein_id is read from the prediction dict by name.
    pred_keys = columns[1:]

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(columns)
    for protein_id, entry in predictions.items():
        for pred in entry.get("all", []):
            row = [protein_id]
            row.extend(pred.get(key, "") for key in pred_keys)
            writer.writerow(row)
    return buffer.getvalue()
# Matches an "OX=<digits>" field inside a FASTA header (e.g. "OX=9606").
_OX_RE = _re.compile(r"OX=(\d+)")


def _parse_taxon_id(header):
    """Return the integer following ``OX=`` in *header*, or None if absent."""
    m = _OX_RE.search(header or "")
    return int(m.group(1)) if m else None


def _fasta_record(record_id, header, seq_lines):
    """Build one protein record dict, or None when the sequence is empty."""
    seq = "".join(seq_lines).upper()
    if not seq:
        return None
    return {"id": record_id, "seq": seq,
            "header": header, "taxon_id": _parse_taxon_id(header)}


def parse_fasta(fasta_text):
    """Parse FASTA text into a list of protein records.

    Each record is ``{"id", "seq", "header", "taxon_id"}``. For headers of
    the form ``db|ACCESSION|name ...`` the id is the accession; otherwise
    (or if the accession is empty) it is the first whitespace-delimited
    token of the header. Sequences are upper-cased; blank lines are ignored;
    headers without any sequence lines are skipped; lines that appear before
    the first header are ignored.

    Raises:
        ValueError: if a header line is empty (just ``>``), or if no record
            with a non-empty sequence is found.
    """
    proteins = []
    current_id, current_hdr, current_seq = None, None, []

    for raw_line in fasta_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if not line.startswith(">"):
            current_seq.append(line)
            continue

        # New header: emit the record collected so far, then start a new one.
        if current_id is not None:
            record = _fasta_record(current_id, current_hdr, current_seq)
            if record is not None:
                proteins.append(record)

        current_hdr = line[1:].strip()
        if not current_hdr:
            # Previously this crashed with IndexError on split()[0].
            raise ValueError("FASTA header line has no identifier.")
        parts = current_hdr.split("|")
        accession = parts[1] if len(parts) >= 3 else ""
        current_id = accession or current_hdr.split()[0]
        current_seq = []

    if current_id is not None:
        record = _fasta_record(current_id, current_hdr, current_seq)
        if record is not None:
            proteins.append(record)

    if not proteins:
        raise ValueError("No valid protein sequences found.")
    return proteins
def _run_prediction(fasta_text, taxon_id_override):
    """Parse FASTA input and run the embedding + prediction pipeline.

    Args:
        fasta_text: raw FASTA text.
        taxon_id_override: taxon id applied to every protein, or None to use
            the taxon id parsed from each header.

    Returns:
        tuple: ``(proteins, raw_preds, ia_weights, elapsed_seconds)`` where
        every item of ``raw_preds`` has had an ``"ia_weight"`` key added
        (rounded to 4 decimals, 0.0 when the term has no weight).

    Raises:
        ValueError: from ``parse_fasta`` or when more than
            ``config.MAX_SEQUENCES`` proteins are submitted.
    """
    proteins = parse_fasta(fasta_text)
    if len(proteins) > config.MAX_SEQUENCES:
        raise ValueError(f"Too many sequences. Max: {config.MAX_SEQUENCES}.")

    protein_ids, sequences, taxon_ids = [], [], []
    for prot in proteins:
        protein_ids.append(prot["id"])
        sequences.append(prot["seq"])
        if taxon_id_override is None:
            taxon_ids.append(prot["taxon_id"])
        else:
            taxon_ids.append(taxon_id_override)

    # Timing covers embedding, feature building, prediction and IA annotation.
    started = time.perf_counter()
    esm_features = embedder.extract(sequences)
    top_taxa = predictor.get_top50_taxa()
    features = embedder.build_features(esm_features, taxon_ids, top_taxa)
    raw_preds = predictor.predict(features, protein_ids)
    ia_weights = predictor.get_ia_weights()
    for pred in raw_preds:
        weight = ia_weights.get(pred["go_term"], 0.0)
        pred["ia_weight"] = round(float(weight), 4)
    elapsed = round(time.perf_counter() - started, 2)
    return proteins, raw_preds, ia_weights, elapsed
# ── Routes ────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Return a short self-description of the API and its endpoints."""
    endpoints = [
        "/health",
        "/model/info",
        "/taxonomy/search",
        "/taxonomy/verify",
        "/predict",
        "/predict/csv",
        "/predict/debug",
    ]
    payload = {
        "name": "FunGO API",
        "version": "2.0.0",
        "status": "running",
        "models_ready": _models_ready,
        "endpoints": endpoints,
    }
    return jsonify(payload)
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report device, FP16 flag, version and model readiness."""
    payload = {
        "status": "ok",
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "version": "2.0.0",
        "models_ready": _models_ready,
    }
    return jsonify(payload)
@app.route("/model/info", methods=["GET"])
def model_info():
    """Describe the loaded models and the tier thresholds in use.

    Responds with 503 while models are not loaded, or when the predictor
    raises RuntimeError while collecting its statistics.
    """
    if not _models_ready:
        return jsonify({"error": "Models not loaded yet."}), 503

    try:
        stats = predictor.get_model_stats()
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503

    # Public tier names mapped onto the config's GOLD / GOOD / SILVER constants.
    thresholds = {
        "STRONG": {"min_ia": config.TIER_GOLD_IA, "min_conf": config.TIER_GOLD_CONF},
        "MODERATE": {"min_ia": config.TIER_GOOD_IA, "min_conf": config.TIER_GOOD_CONF},
        "INDICATIVE": {"min_ia": config.TIER_SILVER_IA, "min_conf": config.TIER_SILVER_CONF},
    }
    payload = {
        "device": config.DEVICE,
        "fp16": config.USE_FP16,
        "model_name": config.MODEL_NAME,
        "ontologies": stats,
        "top50_taxa_count": len(predictor.get_top50_taxa()),
        "thresholds": thresholds,
        "display_limit": flt.TOP_N_DISPLAY,
    }
    return jsonify(payload)
@app.route("/taxonomy/search", methods=["GET"])
def taxonomy_search():
    """Search species by name.

    Query parameters:
        q: search text, at least 2 characters after stripping (else 400).
        max_results: optional integer; non-numeric values fall back to 8 and
            the value is clamped to the range 1..20.
    """
    q = request.args.get("q", "").strip()
    if len(q) < 2:
        return jsonify({"error": "Query must be at least 2 characters."}), 400

    try:
        max_r = int(request.args.get("max_results", 8))
    except (TypeError, ValueError):
        # Only a malformed number is tolerated; nothing else is swallowed.
        max_r = 8
    # Previously only the upper bound was enforced, so 0 or negative values
    # were passed straight through to the search.
    max_r = max(1, min(max_r, 20))

    return jsonify({"query": q, "results": taxonomy.search_species(q, max_results=max_r)})
@app.route("/taxonomy/verify", methods=["GET"])
def taxonomy_verify():
    """Resolve a taxon id against the predictor's top-50 taxa list.

    Query parameters:
        taxon_id: required integer; missing or non-integer values give 400.

    A RuntimeError from the predictor is reported as 503, matching how
    /model/info treats predictor failures.
    """
    raw = request.args.get("taxon_id", "")
    if not raw:
        return jsonify({"error": "taxon_id required."}), 400
    try:
        taxon_id = int(raw)
    except ValueError:
        return jsonify({"error": f"Invalid taxon_id: '{raw}'"}), 400

    try:
        top50 = predictor.get_top50_taxa()
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    return jsonify(taxonomy.resolve_taxon(taxon_id, top50))
@app.route("/predict", methods=["POST"])
def predict():
    """Run GO-term prediction for the submitted FASTA sequences.

    Expects a JSON object with:
        fasta: FASTA text (required, non-empty string).
        taxon_id: optional integer applied to every sequence.

    Responses:
        200: ``{"job_id", "metadata", "predictions"}``; the full filtered
             result set is stored under ``job_id`` for /predict/csv.
        400: missing/invalid input or a ValueError from the pipeline.
        415: request is not JSON.
        503: models not loaded, or a RuntimeError from the pipeline.
        500: any other pipeline failure (logged with traceback).
    """
    if not _models_ready:
        return jsonify({"error": "Models not loaded yet."}), 503
    if not request.is_json:
        return jsonify({"error": "Content-Type must be application/json."}), 415

    # A JSON array/number/string body is valid JSON but not a usable request;
    # treat it like an empty object instead of crashing on .get().
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}

    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str) or not fasta_text.strip():
        return jsonify({"error": "'fasta' field is required."}), 400
    fasta_text = fasta_text.strip()

    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
    except ValueError as e:
        return jsonify({"error": str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    except Exception as e:
        log.exception("Prediction error")
        return jsonify({"error": str(e)}), 500

    # Group raw predictions by the protein they belong to.
    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    predictions, csv_data = {}, {}
    total_display = total_all = 0
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        total_display += len(display)
        total_all += len(all_f)
        predictions[pid] = {
            "taxon_id": prot["taxon_id"],
            "summary": flt.summarise(display, all_f, pid),
            "display": display,
            "total_all": len(all_f),
        }
        # Only the full filtered list is kept for the CSV export.
        csv_data[pid] = {"all": all_f}

    # NOTE(review): the job id is a millisecond timestamp, so two requests
    # finishing in the same millisecond would share (and overwrite) an entry.
    job_id = str(int(time.time() * 1000))
    _store_csv(job_id, csv_data)

    return jsonify({
        "job_id": job_id,
        "metadata": {
            "n_proteins": len(protein_ids),
            "device": config.DEVICE,
            "total_raw_predictions": len(raw_preds),
            "total_filtered": total_all,
            "total_displayed": total_display,
            "display_limit": flt.TOP_N_DISPLAY,
            "elapsed_seconds": elapsed,
        },
        "predictions": predictions,
    })
@app.route("/predict/csv", methods=["GET"])
def download_csv():
    """Serve the stored predictions of a previous /predict job as a CSV file.

    Query parameters:
        job_id: id returned by /predict; 400 when missing, 404 when unknown.
    """
    job_id = request.args.get("job_id", "").strip()
    if not job_id:
        return jsonify({"error": "job_id required."}), 400

    job = _csv_store.get(job_id)
    if not job:
        return jsonify({"error": f"Job '{job_id}' not found."}), 404

    csv_text = _make_csv(job["predictions"])
    disposition = f"attachment; filename=fungo_{job_id}.csv"
    return Response(
        csv_text,
        mimetype="text/csv",
        headers={"Content-Disposition": disposition},
    )
@app.route("/predict/debug", methods=["POST"])
def predict_debug():
    """Run prediction and additionally report why terms were filtered out.

    Accepts the same JSON body as /predict (``fasta`` required, ``taxon_id``
    optional). For each protein the response contains the displayed and
    filtered predictions plus a ``filtered_out`` list giving, for every raw
    prediction not accepted by the filter, a reason string.

    Error mapping matches /predict: ValueError -> 400, RuntimeError -> 503,
    anything else -> 500 (logged with traceback).
    """
    if not _models_ready:
        return jsonify({"error": "Models not loaded."}), 503
    if not request.is_json:
        return jsonify({"error": "Content-Type must be application/json."}), 415

    # Non-object JSON bodies are treated as empty instead of crashing on .get().
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}

    fasta_text = body.get("fasta", "")
    if not isinstance(fasta_text, str) or not fasta_text.strip():
        return jsonify({"error": "'fasta' required."}), 400
    fasta_text = fasta_text.strip()

    taxon_id_override = None
    if "taxon_id" in body:
        try:
            taxon_id_override = int(body["taxon_id"])
        except (TypeError, ValueError, OverflowError):
            return jsonify({"error": "Invalid taxon_id"}), 400

    try:
        proteins, raw_preds, ia_weights, elapsed = _run_prediction(fasta_text, taxon_id_override)
    except ValueError as e:
        # Bad input (unparseable FASTA, too many sequences) is a client error.
        return jsonify({"error": str(e)}), 400
    except RuntimeError as e:
        return jsonify({"error": str(e)}), 503
    except Exception as e:
        log.exception("Debug error")
        return jsonify({"error": str(e)}), 500

    protein_ids = [p["id"] for p in proteins]
    raw_by_pid = {pid: [] for pid in protein_ids}
    for pred in raw_preds:
        raw_by_pid[pred["protein_id"]].append(pred)

    thr = {
        "STRONG": {"min_ia": config.TIER_GOLD_IA, "min_conf": config.TIER_GOLD_CONF},
        "MODERATE": {"min_ia": config.TIER_GOOD_IA, "min_conf": config.TIER_GOOD_CONF},
        "INDICATIVE": {"min_ia": config.TIER_SILVER_IA, "min_conf": config.TIER_SILVER_CONF},
    }

    predictions = {}
    for prot in proteins:
        pid = prot["id"]
        res = flt.filter_predictions(raw_by_pid[pid], ia_weights)
        display, all_f = res["display"], res["all"]
        accepted = {p["go_term"] for p in all_f}

        # Explain every raw prediction that did not survive filtering.
        fo = []
        for pred in raw_by_pid[pid]:
            go = pred["go_term"]
            if go in accepted:
                continue
            ia = pred.get("ia_weight", float(ia_weights.get(go, 0.0)))
            conf = pred["confidence"]
            if go in config.BLACKLIST_TERMS:
                reason = "blacklisted"
            elif ia <= config.TIER_SILVER_IA:
                reason = f"ia_too_low (ia={ia:.4f})"
            elif conf < config.TIER_SILVER_CONF:
                reason = f"conf_too_low (conf={conf:.4f})"
            else:
                reason = "below_all_tiers"
            fo.append({"go_term": go, "ontology": pred["ontology"],
                       "confidence": conf, "ia_weight": ia, "reason": reason})
        # Highest information-accretion weight first.
        fo.sort(key=lambda x: -x["ia_weight"])

        predictions[pid] = {
            "taxon_id": prot["taxon_id"],
            "summary": flt.summarise(display, all_f, pid),
            "display": display,
            "all_filtered": all_f,
            "filtered_out": fo,
            "thresholds_used": thr,
        }

    return jsonify({
        "metadata": {
            "n_proteins": len(protein_ids),
            "device": config.DEVICE,
            "total_raw": len(raw_preds),
            "elapsed_seconds": elapsed,
        },
        "predictions": predictions,
    })
@app.errorhandler(404)
def not_found(e):
    """Answer unknown routes with a JSON 404 body."""
    payload = {"error": "Not found."}
    return jsonify(payload), 404
@app.errorhandler(413)
def too_large(e):
    """Answer oversized request bodies with a JSON 413 body."""
    payload = {"error": "Request too large."}
    return jsonify(payload), 413
@app.errorhandler(500)
def internal(e):
    """Log the unhandled error and answer with a generic JSON 500 body."""
    log.exception("Unhandled error")
    payload = {"error": "Internal server error."}
    return jsonify(payload), 500
def _startup():
    """Prepare directories, fetch model files and load the models.

    Sets the module-level ``_models_ready`` flag to True only when
    ``predictor.load_all()`` succeeds. Failures are logged and leave the
    flag False so the API still starts with predictions disabled.
    """
    global _models_ready
    log.info("FunGO v2.0 starting ...")
    config.ensure_dirs()
    # Download models if not present
    download_models_if_needed()
    # Load models
    if not config.validate_paths():
        log.warning("Some paths missing — predictions disabled")
        return
    try:
        predictor.load_all()
    except Exception as e:
        log.error("Model loading failed: %s", e)
        return
    _models_ready = True
    log.info("Models loaded successfully!")


# ── Module-level startup (gunicorn compatible) ────────────────
# Runs exactly once on import, whether the module is loaded by gunicorn or
# executed directly. Previously the same sequence existed twice, so running
# `python app.py` loaded the models before serving and again after the
# server returned.
_startup()

if __name__ == "__main__":
    # Honour FUNGO_PORT (default "7860" is set at the top of this file).
    try:
        port = int(os.environ.get("FUNGO_PORT", "7860"))
    except ValueError:
        log.warning("Invalid FUNGO_PORT — falling back to 7860")
        port = 7860
    log.info("Serving on port %d ...", port)
    app.run(host="0.0.0.0", port=port, debug=False)