canloc / app.py
Biocoder09's picture
Update app.py
ccff2f3 verified
raw
history blame
4.77 kB
# GLOBAL WARNING SUPPRESSION
# Install a process-wide "ignore" filter for every warning category. It is
# set before the third-party imports below, so warnings raised while those
# modules are being imported are silenced too.
# NOTE(review): this also hides potentially useful warnings (deprecations,
# parser warnings); consider narrowing it by category or module.
import warnings
warnings.simplefilter("ignore")
# IMPORTS
import json
import pickle
import numpy as np
import torch
import io
import csv
from io import StringIO
from typing import List, Dict
from Bio import SeqIO
from fastapi import FastAPI, Request, UploadFile, File
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from transformers import AutoTokenizer, AutoModel
# FASTAPI INIT
app = FastAPI()
# Static assets and Jinja2 page templates; both directory paths are
# relative, so they resolve against the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# MODEL LOADING
# Everything below runs once, at import time.
DEVICE = torch.device("cpu")
# Hugging Face hub id of the ESM-2 protein language model used for embeddings.
ESM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
esm_model = AutoModel.from_pretrained(ESM_MODEL_NAME).to(DEVICE)
esm_model.eval()  # inference mode (e.g. disables dropout)

# Downstream classifier applied to the ESM-2 embeddings.
# SECURITY NOTE(review): pickle.load can execute arbitrary code stored in the
# file. This is acceptable only because model.pkl is a fixed, local artifact;
# never load a user-supplied path here.
with open("model.pkl", "rb") as f:
    classifier = pickle.load(f)

# JSON text is UTF-8 by specification; be explicit so loading does not depend
# on the platform's locale encoding.
with open("label_map.json", "r", encoding="utf-8") as f:
    LABEL_MAP = json.load(f)
# Inverse lookup used to turn a classifier column index into a label.
# NOTE(review): assumes label_map.json maps label -> integer class index
# (JSON object keys are always strings, so the indices must be the values).
# Confirm against the training script.
INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
# ESM2 EMBEDDING
def _clean_sequence(seq: str) -> str:
    """Normalise a raw protein sequence before tokenisation.

    Removes all whitespace (including line breaks inside pasted or wrapped
    sequences) and upper-cases the residues. The ESM-2 vocabulary uses
    upper-case one-letter residue codes, so lower-case letters would
    otherwise be tokenised as unknown tokens.

    Raises:
        ValueError: if no residues remain after cleaning.
    """
    cleaned = "".join(seq.split()).upper()
    if not cleaned:
        raise ValueError("Empty protein sequence")
    return cleaned


def embed_sequence(seq: str) -> np.ndarray:
    """Return the mean-pooled ESM-2 embedding of a single protein sequence.

    Args:
        seq: amino-acid sequence in one-letter codes. Whitespace is ignored
            and letters are upper-cased before tokenisation.

    Returns:
        A 2-D array of shape (1, hidden_size): the mean of the model's last
        hidden states over the residue positions. The first and last token
        positions (the special tokens the tokenizer adds) are excluded.

    Raises:
        ValueError: if ``seq`` contains no residues. Without this check the
            mean over zero residue tokens would be NaN.

    Note:
        ``truncation=True`` silently cuts sequences longer than the
        tokenizer's maximum length, so only the leading residues are embedded.
    """
    seq = _clean_sequence(seq)
    inputs = tokenizer(
        seq,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = esm_model(**inputs)
    # Batch of one: (1, seq_len, hidden) -> (seq_len, hidden).
    token_emb = outputs.last_hidden_state.squeeze(0)
    # Drop the leading and trailing special tokens so only residues are averaged.
    mean_emb = token_emb[1:-1].mean(dim=0)
    return mean_emb.cpu().numpy().reshape(1, -1)
# SINGLE SEQUENCE PREDICTION
def _format_prediction(probs, inv_label_map):
    """Turn one row of class probabilities into the API's result dict.

    Args:
        probs: 1-D sequence of probabilities, one per classifier column.
        inv_label_map: maps a column index (int) to a human-readable label.

    Returns:
        ``{"prediction_label": <label of the largest probability>,
        "probabilities": {label: float probability, ...}}``. Probabilities
        are converted to plain Python floats so they serialise as JSON.
    """
    probs = np.asarray(probs)
    pred_class = int(np.argmax(probs))
    return {
        "prediction_label": inv_label_map[pred_class],
        "probabilities": {
            inv_label_map[i]: float(p) for i, p in enumerate(probs)
        },
    }


def _classify(seq):
    """Embed ``seq`` and score it with the module-level classifier."""
    emb = embed_sequence(seq)
    # NOTE(review): assumes predict_proba column i is class index i (i.e. the
    # classifier was trained on labels 0..n-1). Confirm against training.
    probs = classifier.predict_proba(emb)[0]
    return _format_prediction(probs, INV_LABEL_MAP)


def run_single_prediction(seq: str):
    """Predict the label of one raw protein sequence.

    Returns:
        The dict produced by ``_format_prediction``.

    Raises:
        ValueError: if ``seq`` is not a string or is blank. A blank sequence
            would otherwise yield a NaN embedding.
    """
    if not isinstance(seq, str) or not seq.strip():
        raise ValueError("Sequence must be a non-empty string")
    return _classify(seq)


# FASTA PREDICTION
def run_fasta_prediction(content: str):
    """Predict labels for every record in FASTA-formatted text.

    Records whose sequence is empty are skipped.

    Returns:
        ``{"results": [...]}``. Each entry holds ``sequence`` (the record id),
        ``length`` (the residue count), ``prediction_label`` and
        ``probabilities``, in that key order.
    """
    results = []
    handle = StringIO(content)
    for record in SeqIO.parse(handle, "fasta"):
        seq = str(record.seq).strip()
        if not seq:
            continue  # header with no residues
        results.append({
            "sequence": record.id,
            "length": len(seq),
            **_classify(seq),
        })
    return {"results": results}
# PAGE ROUTES
def _render_page(template_name: str, request: Request):
    """Render ``template_name`` with the incoming request in its context.

    Shared by the page handlers below, which differ only in the template
    they render.
    """
    context = {"request": request}
    return templates.TemplateResponse(template_name, context)


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve ``index.html`` at the site root."""
    return _render_page("index.html", request)


@app.get("/about", response_class=HTMLResponse)
async def about(request: Request):
    """Serve ``about.html``."""
    return _render_page("about.html", request)


@app.get("/help", response_class=HTMLResponse)
async def help_page(request: Request):
    """Serve ``help.html``."""
    return _render_page("help.html", request)


@app.get("/contact", response_class=HTMLResponse)
async def contact(request: Request):
    """Serve ``contact.html``."""
    return _render_page("contact.html", request)
# API: SINGLE SEQUENCE
@app.post("/api/predict_sequence")
async def api_predict_sequence(request: Request):
    """Predict the label of one sequence sent as JSON or as form data.

    Accepts either a JSON object ``{"sequence": "..."}`` or a form field
    named ``sequence``.

    Returns:
        The ``run_single_prediction`` result on success. Otherwise
        ``{"error": <message>}``, sent with the default 200 status as before.
        Unexpected server-side failures are no longer swallowed; they
        propagate to FastAPI's error handling.
    """
    sequence = None

    # 1) JSON body. Only parsing is guarded, so prediction errors can no
    #    longer be misreported as "No sequence provided".
    try:
        data = await request.json()
    except Exception:  # body absent or not valid JSON
        data = None
    if isinstance(data, dict):
        sequence = data.get("sequence")

    # 2) Form body, used when no JSON sequence was found.
    if sequence is None:
        try:
            form = await request.form()
        except Exception:  # body is not a parsable form
            form = None
        if form is not None:
            sequence = form.get("sequence")

    if sequence is None or (isinstance(sequence, str) and not sequence.strip()):
        return {"error": "No sequence provided"}
    if not isinstance(sequence, str):
        # e.g. a JSON number/list or an uploaded file in the form field
        return {"error": "Sequence must be a text value"}

    try:
        return run_single_prediction(sequence)
    except ValueError as exc:
        # Input rejected by validation or by the classifier.
        return {"error": str(exc)}
# API: FASTA FILE
@app.post("/api/predict_fasta")
async def api_predict_fasta(file: UploadFile = File(...)):
    """Predict labels for every record in an uploaded FASTA file.

    Returns:
        ``{"results": [...]}`` as produced by ``run_fasta_prediction``.
    """
    raw = await file.read()
    # "utf-8-sig" strips a leading UTF-8 byte-order mark (written by some
    # editors) and otherwise decodes exactly like "utf-8". With plain "utf-8"
    # the BOM stays in front of the first ">" header, so that header is not
    # recognised. Depending on the Biopython version, the first record is then
    # dropped or parsing fails. Undecodable bytes are still ignored, as before.
    content = raw.decode("utf-8-sig", errors="ignore")
    return run_fasta_prediction(content)
# API: DOWNLOAD CSV
def _flatten_row(row: Dict) -> Dict:
    """Expand nested dict values into ``<key>_<subkey>`` columns.

    Prediction rows carry a nested ``"probabilities"`` mapping. Written as-is,
    the csv module would put its Python repr into a single cell. Non-dict
    values are kept unchanged.
    """
    flat = {}
    for key, value in row.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                flat[f"{key}_{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat


def _collect_fieldnames(rows: List[Dict]) -> List:
    """Return the union of all row keys, in first-seen order."""
    seen = {}
    for row in rows:
        seen.update(dict.fromkeys(row))
    return list(seen)


@app.post("/api/download_csv")
async def download_csv(results: List[Dict]):
    """Serialise posted result rows into a downloadable CSV attachment.

    Args:
        results: JSON array of row objects. Rows may have different keys:
            the header is the union of all keys in first-seen order, and a
            row's missing cells are left empty. Nested objects are expanded
            into ``<key>_<subkey>`` columns.

    Returns:
        A ``text/csv`` attachment named ``canloc_results.csv``, or
        ``{"error": "No results to download"}`` for an empty list.
    """
    if not results:
        return {"error": "No results to download"}
    rows = [_flatten_row(row) for row in results]
    output = io.StringIO()
    # Using only the first row's keys made DictWriter raise ValueError for any
    # later row with an extra key; the union plus restval="" avoids that.
    writer = csv.DictWriter(
        output, fieldnames=_collect_fieldnames(rows), restval=""
    )
    writer.writeheader()
    writer.writerows(rows)
    output.seek(0)
    return StreamingResponse(
        output,
        media_type="text/csv",
        headers={
            "Content-Disposition":
            "attachment; filename=canloc_results.csv"
        },
    )