Biocoder09 commited on
Commit
ccff2f3
·
verified ·
1 Parent(s): 10d9463

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -80
app.py CHANGED
@@ -1,163 +1,214 @@
1
- # GLOBAL WARNING SUPPRESSION
 
2
  import warnings
3
- warnings.filterwarnings("ignore", category=UserWarning)
4
- warnings.filterwarnings("ignore", category=FutureWarning)
5
- warnings.filterwarnings("ignore", category=RuntimeWarning)
6
- warnings.filterwarnings("ignore", category=DeprecationWarning)
7
- warnings.simplefilter("ignore")
8
 
9
- # IMPORTS
10
  import json
11
  import pickle
12
  import numpy as np
13
  import torch
14
- import io
 
15
  from io import StringIO
 
 
16
  from Bio import SeqIO
17
- import csv
18
 
19
  from fastapi import FastAPI, Request, UploadFile, File
20
  from fastapi.staticfiles import StaticFiles
21
- from fastapi.responses import HTMLResponse StreamingResponse
22
  from fastapi.templating import Jinja2Templates
23
 
24
  from transformers import AutoTokenizer, AutoModel
25
 
26
- # FASTAPI INIT + MOUNTS
 
 
27
  app = FastAPI()
28
 
 
29
  app.mount("/static", StaticFiles(directory="static"), name="static")
30
  templates = Jinja2Templates(directory="templates")
31
 
32
- # LOAD MODEL + TOKENIZER
 
 
33
  DEVICE = torch.device("cpu")
34
 
35
- tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
36
- esm_model = AutoModel.from_pretrained("facebook/esm2_t30_150M_UR50D").to(DEVICE)
 
 
 
 
 
 
37
  esm_model.eval()
38
 
39
  with open("model.pkl", "rb") as f:
40
- classifier = pickle.load(f)
41
 
42
  with open("label_map.json", "r") as f:
43
- LABEL_MAP = json.load(f)
44
 
45
  INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
46
 
47
- # ESM2 EMBEDDING FUNCTION
48
 
49
- def embed_sequence(seq: str) -> np.ndarray:
50
- seq = seq.strip()
51
- inputs = tokenizer(seq, return_tensors="pt", add_special_tokens=True)
52
- inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
53
 
54
- with torch.no_grad():
55
- outputs = esm_model(**inputs)
56
 
57
- token_emb = outputs.last_hidden_state.squeeze(0)
58
- mean_emb = token_emb[1:-1].mean(dim=0)
59
- return mean_emb.cpu().numpy().reshape(1, -1)
 
 
 
60
 
61
- # PREDICT ONE SEQUENCE
62
 
63
- def run_single_prediction(seq: str):
64
- emb = embed_sequence(seq)
65
- probs = classifier.predict_proba(emb)[0]
66
- pred_class = int(np.argmax(probs))
67
- pred_label = INV_LABEL_MAP[pred_class]
68
 
69
- return {
70
- "prediction_label": pred_label,
71
- "probabilities": {INV_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
72
- }
73
 
74
- # PREDICT FASTA FILE
75
 
76
- def run_fasta_prediction(content: str):
77
- results = []
78
- handle = StringIO(content)
79
 
80
- for record in SeqIO.parse(handle, "fasta"):
81
- seq = str(record.seq).strip()
82
- if not seq:
83
- continue
84
 
 
85
  emb = embed_sequence(seq)
86
  probs = classifier.predict_proba(emb)[0]
 
87
  pred_class = int(np.argmax(probs))
88
  pred_label = INV_LABEL_MAP[pred_class]
89
 
90
- results.append({
91
- "sequence": record.id,
92
- "length": len(seq),
93
- "prediction_label": pred_label,
94
- "probabilities": {INV_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
95
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
- return {"results": results}
 
98
 
99
- # PAGE ROUTES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  @app.get("/", response_class=HTMLResponse)
102
  async def home(request: Request):
103
- return templates.TemplateResponse("index.html", {"request": request})
 
 
 
104
 
105
  @app.get("/about", response_class=HTMLResponse)
106
  async def about(request: Request):
107
- return templates.TemplateResponse("about.html", {"request": request})
 
 
 
108
 
109
  @app.get("/help", response_class=HTMLResponse)
110
  async def help_page(request: Request):
111
- return templates.TemplateResponse("help.html", {"request": request})
 
 
 
112
 
113
  @app.get("/contact", response_class=HTMLResponse)
114
  async def contact(request: Request):
115
- return templates.TemplateResponse("contact.html", {"request": request})
 
 
 
116
 
117
 
118
- # API: UNIVERSAL SEQUENCE PREDICTION
119
 
120
  @app.post("/api/predict_sequence")
121
  async def api_predict_sequence(request: Request):
122
- # 1. Try JSON
123
- try:
124
- data = await request.json()
125
- if "sequence" in data:
126
- return run_single_prediction(data["sequence"])
127
- except Exception:
128
- pass
129
 
130
- # 2. Try FormData
131
- try:
132
- form = await request.form()
133
- if "sequence" in form:
134
- return run_single_prediction(form["sequence"])
135
- except Exception:
136
- pass
137
 
138
- return {"error": "No sequence provided"}
139
 
140
 
141
- # API: FASTA PREDICTION
142
 
143
  @app.post("/api/predict_fasta")
144
  async def api_predict_fasta(file: UploadFile = File(...)):
145
- raw = await file.read()
146
- content = raw.decode("utf-8", errors="ignore")
147
- return run_fasta_prediction(content)
148
 
149
- # DOWNLOAD RESULT
 
150
 
151
  @app.post("/api/download_csv")
152
- async def download_csv(results: list[dict]):
 
 
 
153
  output = io.StringIO()
154
  writer = csv.DictWriter(output, fieldnames=results[0].keys())
155
  writer.writeheader()
156
  writer.writerows(results)
 
157
  output.seek(0)
158
 
159
  return StreamingResponse(
160
  output,
161
  media_type="text/csv",
162
- headers={"Content-Disposition": "attachment; filename=canloc_results.csv"}
163
- )
 
 
 
 
1
+ # GLOBAL WARNING SUPPRESSION
2
+
3
  import warnings
4
+ warnings.filterwarnings("ignore")
5
+
6
+ # IMPORTS
 
 
7
 
 
8
  import json
9
  import pickle
10
  import numpy as np
11
  import torch
12
+ import io
13
+ import csv
14
  from io import StringIO
15
+ from typing import List, Dict
16
+
17
  from Bio import SeqIO
 
18
 
19
  from fastapi import FastAPI, Request, UploadFile, File
20
  from fastapi.staticfiles import StaticFiles
21
+ from fastapi.responses import HTMLResponse, StreamingResponse
22
  from fastapi.templating import Jinja2Templates
23
 
24
  from transformers import AutoTokenizer, AutoModel
25
 
26
+
27
+ # FASTAPI INIT
28
+
29
  app = FastAPI()
30
 
31
+ # Static + templates
32
  app.mount("/static", StaticFiles(directory="static"), name="static")
33
  templates = Jinja2Templates(directory="templates")
34
 
35
+
36
+ # MODEL LOADING
37
+
38
  DEVICE = torch.device("cpu")
39
 
40
+ tokenizer = AutoTokenizer.from_pretrained(
41
+ "facebook/esm2_t30_150M_UR50D"
42
+ )
43
+
44
+ esm_model = AutoModel.from_pretrained(
45
+ "facebook/esm2_t30_150M_UR50D"
46
+ ).to(DEVICE)
47
+
48
  esm_model.eval()
49
 
50
  with open("model.pkl", "rb") as f:
51
+ classifier = pickle.load(f)
52
 
53
  with open("label_map.json", "r") as f:
54
+ LABEL_MAP = json.load(f)
55
 
56
  INV_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
57
 
 
58
 
59
+ # ESM2 EMBEDDING
 
 
 
60
 
61
+ def embed_sequence(seq: str) -> np.ndarray:
62
+ seq = seq.strip()
63
 
64
+ inputs = tokenizer(
65
+ seq,
66
+ return_tensors="pt",
67
+ add_special_tokens=True,
68
+ truncation=True
69
+ )
70
 
71
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
72
 
73
+ with torch.no_grad():
74
+ outputs = esm_model(**inputs)
 
 
 
75
 
76
+ token_emb = outputs.last_hidden_state.squeeze(0)
77
+ mean_emb = token_emb[1:-1].mean(dim=0)
 
 
78
 
79
+ return mean_emb.cpu().numpy().reshape(1, -1)
80
 
 
 
 
81
 
82
+ # SINGLE SEQUENCE PREDICTION
 
 
 
83
 
84
+ def run_single_prediction(seq: str):
85
  emb = embed_sequence(seq)
86
  probs = classifier.predict_proba(emb)[0]
87
+
88
  pred_class = int(np.argmax(probs))
89
  pred_label = INV_LABEL_MAP[pred_class]
90
 
91
+ return {
92
+ "prediction_label": pred_label,
93
+ "probabilities": {
94
+ INV_LABEL_MAP[i]: float(p)
95
+ for i, p in enumerate(probs)
96
+ }
97
+ }
98
+
99
+
100
+ # FASTA PREDICTION
101
+
102
+ def run_fasta_prediction(content: str):
103
+ results = []
104
+ handle = StringIO(content)
105
+
106
+ for record in SeqIO.parse(handle, "fasta"):
107
+ seq = str(record.seq).strip()
108
+ if not seq:
109
+ continue
110
 
111
+ emb = embed_sequence(seq)
112
+ probs = classifier.predict_proba(emb)[0]
113
 
114
+ pred_class = int(np.argmax(probs))
115
+ pred_label = INV_LABEL_MAP[pred_class]
116
+
117
+ results.append({
118
+ "sequence": record.id,
119
+ "length": len(seq),
120
+ "prediction_label": pred_label,
121
+ "probabilities": {
122
+ INV_LABEL_MAP[i]: float(p)
123
+ for i, p in enumerate(probs)
124
+ }
125
+ })
126
+
127
+ return {"results": results}
128
+
129
+
130
+ # PAGE ROUTES
131
 
132
  @app.get("/", response_class=HTMLResponse)
133
  async def home(request: Request):
134
+ return templates.TemplateResponse(
135
+ "index.html",
136
+ {"request": request}
137
+ )
138
 
139
  @app.get("/about", response_class=HTMLResponse)
140
  async def about(request: Request):
141
+ return templates.TemplateResponse(
142
+ "about.html",
143
+ {"request": request}
144
+ )
145
 
146
  @app.get("/help", response_class=HTMLResponse)
147
  async def help_page(request: Request):
148
+ return templates.TemplateResponse(
149
+ "help.html",
150
+ {"request": request}
151
+ )
152
 
153
  @app.get("/contact", response_class=HTMLResponse)
154
  async def contact(request: Request):
155
+ return templates.TemplateResponse(
156
+ "contact.html",
157
+ {"request": request}
158
+ )
159
 
160
 
161
+ # API: SINGLE SEQUENCE
162
 
163
  @app.post("/api/predict_sequence")
164
  async def api_predict_sequence(request: Request):
165
+ # Try JSON
166
+ try:
167
+ data = await request.json()
168
+ if "sequence" in data:
169
+ return run_single_prediction(data["sequence"])
170
+ except Exception:
171
+ pass
172
 
173
+ # Try Form
174
+ try:
175
+ form = await request.form()
176
+ if "sequence" in form:
177
+ return run_single_prediction(form["sequence"])
178
+ except Exception:
179
+ pass
180
 
181
+ return {"error": "No sequence provided"}
182
 
183
 
184
+ # API: FASTA FILE
185
 
186
  @app.post("/api/predict_fasta")
187
  async def api_predict_fasta(file: UploadFile = File(...)):
188
+ raw = await file.read()
189
+ content = raw.decode("utf-8", errors="ignore")
190
+ return run_fasta_prediction(content)
191
 
192
+
193
+ # API: DOWNLOAD CSV
194
 
195
  @app.post("/api/download_csv")
196
+ async def download_csv(results: List[Dict]):
197
+ if not results:
198
+ return {"error": "No results to download"}
199
+
200
  output = io.StringIO()
201
  writer = csv.DictWriter(output, fieldnames=results[0].keys())
202
  writer.writeheader()
203
  writer.writerows(results)
204
+
205
  output.seek(0)
206
 
207
  return StreamingResponse(
208
  output,
209
  media_type="text/csv",
210
+ headers={
211
+ "Content-Disposition":
212
+ "attachment; filename=canloc_results.csv"
213
+ }
214
+ )