Spaces:
Sleeping
Sleeping
| # GLOBAL WARNING SUPPRESSION | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # IMPORTS | |
| import json | |
| import pickle | |
| import numpy as np | |
| import torch | |
| import io | |
| import csv | |
| from io import StringIO | |
| from typing import List, Dict | |
| from Bio import SeqIO | |
| from fastapi import FastAPI, Request, UploadFile, File | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, StreamingResponse | |
| from fastapi.templating import Jinja2Templates | |
| from transformers import AutoTokenizer, AutoModel | |
# FASTAPI INIT
app = FastAPI()

# Static + templates
# Both directories are resolved relative to the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# MODEL LOADING
# Everything in this section runs once, at import time.
DEVICE = torch.device("cpu")

# The tokenizer and the encoder must come from the same checkpoint, so the
# name is spelled once and shared by both loaders.
ESM_CHECKPOINT = "facebook/esm2_t30_150M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
esm_model = AutoModel.from_pretrained(ESM_CHECKPOINT).to(DEVICE)
esm_model.eval()

# SECURITY NOTE: unpickling executes arbitrary code embedded in the file.
# "model.pkl" must only ever be a trusted artifact shipped with the app,
# never something a user can upload or overwrite.
with open("model.pkl", "rb") as f:
    classifier = pickle.load(f)

# An explicit encoding keeps the label names independent of the host locale.
with open("label_map.json", "r", encoding="utf-8") as f:
    LABEL_MAP = json.load(f)

# Reverse lookup from the values of LABEL_MAP back to its keys.
# NOTE(review): presumably LABEL_MAP is {label name: integer class index} —
# confirm against the training code.
INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
| # ESM2 EMBEDDING | |
def embed_sequence(seq: str) -> np.ndarray:
    """Embed one sequence with the ESM2 encoder and mean-pool it.

    Args:
        seq: Raw sequence text. Leading and trailing whitespace is removed
            before tokenization.

    Returns:
        A NumPy array reshaped to a single row, ``(1, hidden_size)``, which
        is the layout passed to ``classifier.predict_proba`` by the callers.

    Raises:
        ValueError: If ``seq`` is empty or whitespace-only. Without this
            guard the pooled slice below would be empty and its mean would
            be NaN.
    """
    seq = seq.strip()
    if not seq:
        raise ValueError("Cannot embed an empty sequence")

    # NOTE(review): truncation=True is given without max_length, so the cut-off
    # is whatever the tokenizer's configured maximum is — confirm it is the
    # intended limit for long sequences.
    inputs = tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = esm_model(**inputs)

    # Drop the batch dimension (a single sequence was encoded).
    token_emb = outputs.last_hidden_state.squeeze(0)

    # Skip the first and last positions — presumably the special tokens added
    # by add_special_tokens=True — and average the remaining positions.
    mean_emb = token_emb[1:-1].mean(dim=0)
    return mean_emb.cpu().numpy().reshape(1, -1)
| # SINGLE SEQUENCE PREDICTION | |
def run_single_prediction(seq: str):
    """Classify a single sequence.

    The sequence is embedded with ``embed_sequence`` and the embedding is
    scored with ``classifier.predict_proba``; the first (only) row of the
    result is used.

    Returns:
        A dict with ``"prediction_label"`` (the label of the highest-scoring
        position) and ``"probabilities"`` (label -> float for every position
        of the probability vector).

    NOTE(review): labels are looked up in INV_LABEL_MAP by position in the
    probability vector — this assumes the classifier's class order matches
    the label-map indices; confirm.
    """
    embedding = embed_sequence(seq)
    class_probs = classifier.predict_proba(embedding)[0]

    best_index = int(np.argmax(class_probs))
    best_label = INV_LABEL_MAP[best_index]

    # Cast to built-in floats when building the per-label mapping.
    per_label = {}
    for index, prob in enumerate(class_probs):
        per_label[INV_LABEL_MAP[index]] = float(prob)

    return {
        "prediction_label": best_label,
        "probabilities": per_label,
    }
| # FASTA PREDICTION | |
def run_fasta_prediction(content: str):
    """Classify every non-empty record in FASTA-formatted text.

    Args:
        content: The FASTA text, already decoded to ``str``.

    Returns:
        ``{"results": [...]}`` with one dict per record, in file order. Each
        dict holds ``"sequence"`` (the record id), ``"length"`` (number of
        characters in the stripped sequence), and the ``"prediction_label"``
        and ``"probabilities"`` entries produced by ``run_single_prediction``.
        Records whose sequence is empty after stripping are skipped.
    """
    results = []
    handle = StringIO(content)

    for record in SeqIO.parse(handle, "fasta"):
        seq = str(record.seq).strip()
        if not seq:
            # A header with no residues has nothing to embed.
            continue

        # Delegate to the single-sequence path so both entry points share one
        # implementation of the predict / argmax / label-lookup logic.
        results.append({
            "sequence": record.id,
            "length": len(seq),
            **run_single_prediction(seq),
        })

    return {"results": results}
| # PAGE ROUTES | |
async def home(request: Request):
    """Render the ``index.html`` template for the incoming request."""
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def about(request: Request):
    """Render the ``about.html`` template for the incoming request."""
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    context = {"request": request}
    return templates.TemplateResponse("about.html", context)
async def help_page(request: Request):
    """Render the ``help.html`` template for the incoming request."""
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    context = {"request": request}
    return templates.TemplateResponse("help.html", context)
async def contact(request: Request):
    """Render the ``contact.html`` template for the incoming request."""
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    context = {"request": request}
    return templates.TemplateResponse("contact.html", context)
| # API: SINGLE SEQUENCE | |
async def api_predict_sequence(request: Request):
    """Predict the class of one sequence sent as JSON or as form data.

    The body is first read as JSON; if that yields no ``"sequence"`` value,
    it is read as form data. The first non-null ``"sequence"`` found is used.

    Returns:
        The dict produced by ``run_single_prediction`` on success, or
        ``{"error": "No sequence provided"}`` when neither body format
        carries a non-blank string under ``"sequence"``.

    Only body parsing is treated as best-effort. Errors raised by the
    prediction itself are allowed to propagate, so a model failure is no
    longer misreported as a missing sequence.
    """
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    sequence = None

    # Try JSON. A non-JSON body is expected here, so parse failures are
    # ignored and the form fallback is tried instead.
    try:
        payload = await request.json()
    except Exception:
        payload = None
    if isinstance(payload, dict):
        sequence = payload.get("sequence")

    # Try Form. Parsing can fail for bodies that are not form-encoded.
    if sequence is None:
        try:
            form = await request.form()
        except Exception:
            form = None
        if form is not None:
            sequence = form.get("sequence")

    # Reject missing, non-text (e.g. an uploaded file or a number) and
    # blank values before they reach the model.
    if not isinstance(sequence, str) or not sequence.strip():
        return {"error": "No sequence provided"}

    return run_single_prediction(sequence)
| # API: FASTA FILE | |
async def api_predict_fasta(file: UploadFile = File(...)):
    """Predict classes for every record of an uploaded FASTA file.

    Args:
        file: The uploaded file; its whole content is read into memory.

    Returns:
        The ``{"results": [...]}`` dict produced by ``run_fasta_prediction``.
    """
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    raw = await file.read()

    # "utf-8-sig" decodes exactly like "utf-8" but also drops a leading
    # byte-order mark. With plain "utf-8" a BOM would sit in front of the
    # first ">" so the first header line would no longer start with ">".
    # Undecodable bytes are still dropped rather than raising.
    content = raw.decode("utf-8-sig", errors="ignore")
    return run_fasta_prediction(content)
| # API: DOWNLOAD CSV | |
async def download_csv(results: List[Dict]):
    """Serialize result rows to CSV and return them as a file download.

    Args:
        results: Rows to write, one dict per CSV line.

    Returns:
        ``{"error": "No results to download"}`` when ``results`` is empty;
        otherwise a ``StreamingResponse`` of type ``text/csv`` served as the
        attachment ``canloc_results.csv``.

    Columns are the keys of all rows in first-seen order. For rows that all
    share the same keys this is the same header as using the first row's
    keys. A row with an extra key no longer makes ``DictWriter`` raise
    ``ValueError``, and a row with a missing key gets an empty cell.
    """
    # NOTE(review): no route decorator is visible on this handler here —
    # confirm it is registered on `app`.
    if not results:
        return {"error": "No results to download"}

    # dict.fromkeys de-duplicates while preserving insertion order.
    fieldnames = list(dict.fromkeys(key for row in results for key in row))

    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(results)

    # Rewind so the response streams the buffer from its beginning.
    output.seek(0)

    return StreamingResponse(
        output,
        media_type="text/csv",
        headers={
            "Content-Disposition":
                "attachment; filename=canloc_results.csv"
        },
    )