Biocoder09 commited on
Commit
b5af915
·
verified ·
1 Parent(s): beb0c16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,6 +1,4 @@
1
- # ================================
2
  # GLOBAL WARNING SUPPRESSION
3
- # ================================
4
  import warnings
5
  warnings.filterwarnings("ignore", category=UserWarning)
6
  warnings.filterwarnings("ignore", category=FutureWarning)
@@ -8,34 +6,30 @@ warnings.filterwarnings("ignore", category=RuntimeWarning)
8
  warnings.filterwarnings("ignore", category=DeprecationWarning)
9
  warnings.simplefilter("ignore")
10
 
11
- # ================================
12
  # IMPORTS
13
- # ================================
14
  import json
15
  import pickle
16
  import numpy as np
17
  import torch
 
18
  from io import StringIO
19
  from Bio import SeqIO
 
20
 
21
  from fastapi import FastAPI, Request, UploadFile, File
22
  from fastapi.staticfiles import StaticFiles
23
- from fastapi.responses import HTMLResponse
24
  from fastapi.templating import Jinja2Templates
25
 
26
  from transformers import AutoTokenizer, AutoModel
27
 
28
- # ================================
29
  # FASTAPI INIT + MOUNTS
30
- # ================================
31
  app = FastAPI()
32
 
33
  app.mount("/static", StaticFiles(directory="static"), name="static")
34
  templates = Jinja2Templates(directory="templates")
35
 
36
- # ================================
37
  # LOAD MODEL + TOKENIZER
38
- # ================================
39
  DEVICE = torch.device("cpu")
40
 
41
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
@@ -50,9 +44,8 @@ with open("label_map.json", "r") as f:
50
 
51
  INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
52
 
53
- # ================================
54
  # ESM2 EMBEDDING FUNCTION
55
- # ================================
56
  def embed_sequence(seq: str) -> np.ndarray:
57
  seq = seq.strip()
58
  inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
@@ -65,9 +58,8 @@ def embed_sequence(seq: str) -> np.ndarray:
65
  mean_emb = token_emb[1:-1].mean(dim=0)
66
  return mean_emb.cpu().numpy().reshape(1, -1)
67
 
68
- # ================================
69
  # PREDICT ONE SEQUENCE
70
- # ================================
71
  def run_single_prediction(seq: str):
72
  emb = embed_sequence(seq)
73
  probs = classifier.predict_proba(emb)[0]
@@ -79,9 +71,8 @@ def run_single_prediction(seq: str):
79
  "probabilities": {INV_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
80
  }
81
 
82
- # ================================
83
  # PREDICT FASTA FILE
84
- # ================================
85
  def run_fasta_prediction(content: str):
86
  results = []
87
  handle = StringIO(content)
@@ -105,9 +96,8 @@ def run_fasta_prediction(content: str):
105
 
106
  return {"results": results}
107
 
108
- # ================================
109
  # PAGE ROUTES
110
- # ================================
111
  @app.get("/", response_class=HTMLResponse)
112
  async def home(request: Request):
113
  return templates.TemplateResponse("index.html", {"request": request})
@@ -124,9 +114,9 @@ async def help_page(request: Request):
124
  async def contact(request: Request):
125
  return templates.TemplateResponse("contact.html", {"request": request})
126
 
127
- # ================================
128
  # API: UNIVERSAL SEQUENCE PREDICTION
129
- # ================================
130
  @app.post("/api/predict_sequence")
131
  async def api_predict_sequence(request: Request):
132
  # 1. Try JSON
@@ -147,11 +137,27 @@ async def api_predict_sequence(request: Request):
147
 
148
  return {"error": "No sequence provided"}
149
 
150
- # ================================
151
  # API: FASTA PREDICTION
152
- # ================================
153
  @app.post("/api/predict_fasta")
154
  async def api_predict_fasta(file: UploadFile = File(...)):
155
  raw = await file.read()
156
  content = raw.decode("utf-8", errors="ignore")
157
  return run_fasta_prediction(content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # GLOBAL WARNING SUPPRESSION
 
2
  import warnings
3
  warnings.filterwarnings("ignore", category=UserWarning)
4
  warnings.filterwarnings("ignore", category=FutureWarning)
 
6
  warnings.filterwarnings("ignore", category=DeprecationWarning)
7
  warnings.simplefilter("ignore")
8
 
 
9
  # IMPORTS
 
10
  import json
11
  import pickle
12
  import numpy as np
13
  import torch
14
+ import io
15
  from io import StringIO
16
  from Bio import SeqIO
17
+ import csv
18
 
19
  from fastapi import FastAPI, Request, UploadFile, File
20
  from fastapi.staticfiles import StaticFiles
21
+ from fastapi.responses import HTMLResponse StreamingResponse
22
  from fastapi.templating import Jinja2Templates
23
 
24
  from transformers import AutoTokenizer, AutoModel
25
 
 
26
  # FASTAPI INIT + MOUNTS
 
27
  app = FastAPI()
28
 
29
  app.mount("/static", StaticFiles(directory="static"), name="static")
30
  templates = Jinja2Templates(directory="templates")
31
 
 
32
  # LOAD MODEL + TOKENIZER
 
33
  DEVICE = torch.device("cpu")
34
 
35
  tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
 
44
 
45
  INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
46
 
 
47
  # ESM2 EMBEDDING FUNCTION
48
+
49
  def embed_sequence(seq: str) -> np.ndarray:
50
  seq = seq.strip()
51
  inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
 
58
  mean_emb = token_emb[1:-1].mean(dim=0)
59
  return mean_emb.cpu().numpy().reshape(1, -1)
60
 
 
61
  # PREDICT ONE SEQUENCE
62
+
63
  def run_single_prediction(seq: str):
64
  emb = embed_sequence(seq)
65
  probs = classifier.predict_proba(emb)[0]
 
71
  "probabilities": {INV_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
72
  }
73
 
 
74
  # PREDICT FASTA FILE
75
+
76
  def run_fasta_prediction(content: str):
77
  results = []
78
  handle = StringIO(content)
 
96
 
97
  return {"results": results}
98
 
 
99
  # PAGE ROUTES
100
+
101
  @app.get("/", response_class=HTMLResponse)
102
  async def home(request: Request):
103
  return templates.TemplateResponse("index.html", {"request": request})
 
114
  async def contact(request: Request):
115
  return templates.TemplateResponse("contact.html", {"request": request})
116
 
117
+
118
  # API: UNIVERSAL SEQUENCE PREDICTION
119
+
120
  @app.post("/api/predict_sequence")
121
  async def api_predict_sequence(request: Request):
122
  # 1. Try JSON
 
137
 
138
  return {"error": "No sequence provided"}
139
 
140
+
141
  # API: FASTA PREDICTION
142
+
143
  @app.post("/api/predict_fasta")
144
  async def api_predict_fasta(file: UploadFile = File(...)):
145
  raw = await file.read()
146
  content = raw.decode("utf-8", errors="ignore")
147
  return run_fasta_prediction(content)
148
+
149
+ # DOWNLOAD RESULT
150
+
151
+ @app.post("/api/download_csv")
152
+ async def download_csv(results: list[dict]):
153
+ output = io.StringIO()
154
+ writer = csv.DictWriter(output, fieldnames=results[0].keys())
155
+ writer.writeheader()
156
+ writer.writerows(results)
157
+ output.seek(0)
158
+
159
+ return StreamingResponse(
160
+ output,
161
+ media_type="text/csv",
162
+ headers={"Content-Disposition": "attachment; filename=canloc_results.csv"}
163
+ )