# GLOBAL WARNING SUPPRESSION
# NOTE(review): this silences *every* warning process-wide (including
# deprecation warnings from FastAPI/Starlette/transformers); consider
# narrowing it to specific categories/modules.
import warnings

warnings.filterwarnings("ignore")

# IMPORTS
import csv
import io
import json
import pickle
import threading
from io import StringIO
from typing import Dict, List, Optional

import numpy as np
import torch
from Bio import SeqIO
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import AutoModel, AutoTokenizer

# FASTAPI INIT
app = FastAPI()

# Static + templates
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# MODEL LOADING
ESM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
DEVICE = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
esm_model = AutoModel.from_pretrained(ESM_MODEL_NAME).to(DEVICE)
esm_model.eval()

# NOTE(security): pickle.load can execute arbitrary code. model.pkl must
# come from a trusted source and must never be replaced by user input.
with open("model.pkl", "rb") as f:
    classifier = pickle.load(f)

with open("label_map.json", "r", encoding="utf-8") as f:
    LABEL_MAP = json.load(f)

# Inverted map used to turn a class index back into a label name.
# NOTE(review): assumes label_map.json maps label name -> integer class id
# (e.g. {"Nucleus": 0, ...}); if its values are strings, the integer
# lookups below will raise KeyError — confirm against the training script.
INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}

# Inference runs on worker threads (the API routes use run_in_threadpool so
# the event loop stays responsive). The tokenizer and model are shared
# module-level objects, so calls into them are serialized with this lock.
_INFERENCE_LOCK = threading.Lock()


# SEQUENCE CLEANING
def _clean_sequence(seq: str) -> str:
    """Return ``seq`` with all whitespace removed and letters upper-cased.

    Pasted sequences often contain line breaks or spaces, and lower-case
    residues are presumably not in the ESM-2 vocabulary (they would
    tokenize as unknown) — confirm against ``tokenizer.get_vocab()``.
    """
    return "".join(seq.split()).upper()


# ESM2 EMBEDDING
def embed_sequence(seq: str) -> np.ndarray:
    """Return the mean-pooled ESM-2 embedding of ``seq`` as a (1, N) array.

    The sequence is cleaned with ``_clean_sequence`` first.

    Raises:
        ValueError: if ``seq`` is empty after cleaning. (Previously this
            silently produced a NaN embedding from a zero-length mean.)
    """
    seq = _clean_sequence(seq)
    if not seq:
        raise ValueError("Cannot embed an empty sequence")

    with _INFERENCE_LOCK:
        inputs = tokenizer(
            seq,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
        )
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = esm_model(**inputs)

    token_emb = outputs.last_hidden_state.squeeze(0)
    # Drop the first and last positions (the special tokens added by
    # add_special_tokens=True) so only residue positions are averaged.
    mean_emb = token_emb[1:-1].mean(dim=0)
    return mean_emb.cpu().numpy().reshape(1, -1)


# SHARED CLASSIFICATION
def _classify(seq: str):
    """Embed ``seq`` and return ``(predicted_label, {label: probability})``."""
    emb = embed_sequence(seq)
    probs = classifier.predict_proba(emb)[0]
    # NOTE(review): assumes predict_proba column i corresponds to class id i
    # (i.e. classifier.classes_ == [0, 1, ..., n-1]) — confirm.
    pred_label = INV_LABEL_MAP[int(np.argmax(probs))]
    probabilities = {INV_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
    return pred_label, probabilities


# SINGLE SEQUENCE PREDICTION
def run_single_prediction(seq: str):
    """Classify one sequence.

    Returns:
        ``{"prediction_label": str, "probabilities": {label: float}}``.

    Raises:
        ValueError: if ``seq`` is empty after cleaning.
    """
    pred_label, probabilities = _classify(seq)
    return {
        "prediction_label": pred_label,
        "probabilities": probabilities,
    }


# FASTA PREDICTION
def run_fasta_prediction(content: str):
    """Classify every non-empty record in FASTA-formatted ``content``.

    Returns:
        ``{"results": [...]}``. Each entry holds the record id (under
        ``"sequence"``), the cleaned sequence length, the predicted label
        and the per-label probabilities. Records that are empty after
        cleaning are skipped.
    """
    results = []
    for record in SeqIO.parse(StringIO(content), "fasta"):
        seq = _clean_sequence(str(record.seq))
        if not seq:
            continue
        pred_label, probabilities = _classify(seq)
        results.append({
            "sequence": record.id,
            "length": len(seq),
            "prediction_label": pred_label,
            "probabilities": probabilities,
        })
    return {"results": results}


# PAGE ROUTES
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.get("/about", response_class=HTMLResponse)
async def about(request: Request):
    return templates.TemplateResponse("about.html", {"request": request})


@app.get("/help", response_class=HTMLResponse)
async def help_page(request: Request):
    return templates.TemplateResponse("help.html", {"request": request})


@app.get("/contact", response_class=HTMLResponse)
async def contact(request: Request):
    return templates.TemplateResponse("contact.html", {"request": request})


# API: SINGLE SEQUENCE
async def _extract_sequence(request: Request) -> Optional[str]:
    """Return the string ``sequence`` field from a JSON or form body, or None.

    The JSON body is tried first, then form data. Only body *parsing* is
    guarded here; prediction errors are not swallowed.
    """
    try:
        data = await request.json()
    except ValueError:  # body is not valid JSON (e.g. a form post)
        data = None
    if isinstance(data, dict) and isinstance(data.get("sequence"), str):
        return data["sequence"]

    try:
        form = await request.form()
    except Exception:  # parsing boundary: no/invalid form body
        return None
    value = form.get("sequence")
    return value if isinstance(value, str) else None


@app.post("/api/predict_sequence")
async def api_predict_sequence(request: Request):
    """Predict from a single sequence sent as JSON or form field ``sequence``.

    Errors raised during prediction now propagate. Previously they were
    swallowed and misreported as "No sequence provided".
    """
    sequence = await _extract_sequence(request)
    if sequence is None or not sequence.strip():
        return {"error": "No sequence provided"}
    # CPU-bound model inference: keep it off the event loop.
    return await run_in_threadpool(run_single_prediction, sequence)


# API: FASTA FILE
@app.post("/api/predict_fasta")
async def api_predict_fasta(file: UploadFile = File(...)):
    """Predict every record of an uploaded FASTA file."""
    # NOTE(review): the whole upload is read into memory with no size limit.
    raw = await file.read()
    content = raw.decode("utf-8", errors="ignore")
    # CPU-bound model inference: keep it off the event loop.
    return await run_in_threadpool(run_fasta_prediction, content)


# API: DOWNLOAD CSV
@app.post("/api/download_csv")
async def download_csv(results: List[Dict]):
    """Return ``results`` (a list of flat dicts) as a CSV attachment.

    Nested values (e.g. a ``probabilities`` dict) are written using their
    ``str()`` form by ``csv.DictWriter``.
    """
    if not results:
        return {"error": "No results to download"}

    # The columns are the union of keys across all rows, in first-seen
    # order. DictWriter raises ValueError on a row that has a key missing
    # from fieldnames, so using only the first row's keys broke mixed rows.
    # Rows lacking a column get an empty cell.
    fieldnames = list(dict.fromkeys(key for row in results for key in row))

    output = io.StringIO()
    writer = csv.DictWriter(output, fieldnames=fieldnames, restval="")
    writer.writeheader()
    writer.writerows(results)
    output.seek(0)

    return StreamingResponse(
        output,
        media_type="text/csv",
        headers={
            "Content-Disposition": "attachment; filename=canloc_results.csv"
        },
    )