eloise54 commited on
Commit
3ad7268
·
1 Parent(s): 24b7fe6
Files changed (2) hide show
  1. Dockerfile +3 -5
  2. app.py +376 -0
Dockerfile CHANGED
@@ -3,13 +3,11 @@ FROM gcr.io/kaggle-gpu-images/python@sha256:7a9d2a6b13b3566aa6cc5a22447f3b3f99ac
3
  RUN apt-get update && apt-get install -y git-lfs
4
  RUN git lfs install
5
 
6
- RUN git clone https://gitlab.com/nn_projects/cafa6_project.git /root/cafa6_project
7
-
8
- WORKDIR /root/cafa6_project
9
-
10
  EXPOSE 7860
11
 
12
  ENV GRADIO_SERVER_NAME=0.0.0.0
13
  ENV GRADIO_SERVER_PORT=7860
14
 
15
- CMD ["python", "app.py"]
 
 
 
# (FROM line lives above this hunk: a digest-pinned Kaggle GPU Python image.)

# git-lfs so LFS-tracked artifacts resolve when app.py clones the project repo.
RUN apt-get update && apt-get install -y git-lfs
RUN git lfs install

# Port and bind address expected by Gradio / HF Spaces.
EXPOSE 7860

ENV GRADIO_SERVER_NAME=0.0.0.0
ENV GRADIO_SERVER_PORT=7860

# Only app.py is baked into the image; it clones the project repo at startup.
COPY app.py /root/app.py

CMD ["python", "/root/app.py"]
app.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
############################################ INSTALL PACKAGES ############################################
# Runtime bootstrap: the base image may not ship these exact versions, so they
# are pip-installed before anything else is imported.

import sys
import subprocess

def install(*packages):
    """Install (or upgrade) the given pip requirement specifiers.

    Accepts one or more specifiers. A single pip invocation lets the resolver
    consider all requirements jointly instead of one resolver run per package.
    Backward compatible with the previous single-argument usage.

    Raises subprocess.CalledProcessError if pip exits non-zero.
    """
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", *packages]
    )

install(
    "gradio>=3.44.0",
    "biopython==1.86",
    "cachetools==5.4.0",
    "mlflow==3.7.0",
)
14
+
15
+
16
+
17
# Clone the project repository holding the MLflow artifacts this app loads.
import os  # FIX: os was referenced below without ever being imported (NameError at startup)
import subprocess

HF_REPO_URL = "https://gitlab.com/nn_projects/cafa6_project"
CLONE_DIR = "/root/cafa6_project"

# Clone only once so restarts of a warm container are cheap. Use an argument
# list (no shell) and check_call so a failed clone fails loudly here instead
# of as a confusing missing-file error later.
if not os.path.exists(CLONE_DIR):
    subprocess.check_call(["git", "clone", HF_REPO_URL, CLONE_DIR])

# All relative paths below (input_path, data_dir, ...) resolve inside the clone.
os.chdir(CLONE_DIR)
24
############################################ DEFINE CONSTANTS ############################################
# Core imports and reproducibility constants for the inference app.
# NOTE(review): `os` is already referenced above this point (clone section) —
# this import should move to the very top of the file.
import os
import gc
import pandas as pd
import numpy as np
from collections import defaultdict
from tqdm.auto import tqdm
import mlflow
import torch
import random
import requests
import re
from transformers import set_seed
from torch.utils.data import Dataset, DataLoader
from Bio import SeqIO
import gradio as gr

# Paths are relative to the cloned project repo (os.chdir above).
input_path = './'
data_dir = "numpy_dataset/"
test_embeddings_data = "prot_t5_embeddings_right_pooling_False_test_mini"

# NOTE(review): test_batch_size, data_dir and test_embeddings_data appear
# unused in this file — confirm against the project before removing.
test_batch_size = 64
SEED = 42
MAX_SEQ_LEN = 512   # ProtT5 tokenization length (pretraining max, see below)
HIDDEN_DIM = 1024
THRESH = 0.003      # minimum ensemble score for a GO term to be kept

# Seed every RNG in play so results are reproducible across restarts.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
set_seed(SEED)
56
+
57
+
58
############################################ LOAD MODELS ############################################

# Run on GPU when available; embedding + ensemble inference are heavy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("USING DEVICE: ", device)

# MLflow run that produced the ensemble; artifacts live inside the cloned repo.
run_id = '9ee8f63638d0494ea20b63710c19a8b3'
SUBMISSION_INPUT_PATH = input_path + 'mlruns/11/' + run_id + '/artifacts/'
SUBMISSION_INPUT = SUBMISSION_INPUT_PATH+ 'submission.tsv'
OUTPUT_PATH = SUBMISSION_INPUT_PATH + '/diamond/'
# .npy files listing the MLflow URIs of each fold's model and label binarizer.
models_uri = np.load(SUBMISSION_INPUT_PATH + "models_uri_C_fold_4.npy")
mlb_arrays_uri = np.load(SUBMISSION_INPUT_PATH + "mlb_arrays_uri_C_fold_4.npy")

#LOAD MODELS
# One PyTorch model per fold, frozen in eval mode on `device`.
MODELS = []
for uri in models_uri:
    model = mlflow.pytorch.load_model(uri)
    model.eval()
    model.to(device)
    MODELS.append(model)

#LOAD ONE HOT MLB ARRAYS
# MLB_ARRAYS[k] holds the ordered GO ids (label space) of model k's output.
MLB_ARRAYS = [
    np.load(uri, allow_pickle=True)
    for uri in mlb_arrays_uri
]
83
+
84
#CREATE MATRIX ADAPTATORS TO MAP UNIQUE GO IDS TO EACH MODEL'S MLB ARRAY PREDICTION
concatenated_array = np.concatenate(MLB_ARRAYS)
unique_go_ids = np.array(list(set(concatenated_array)))
print(concatenated_array.shape)
print(unique_go_ids.shape)

# PERF FIX: the original built each adaptator with a nested Python loop over
# all (unique_go, label) pairs — O(U*L) per model, billions of iterations for
# realistic GO vocabularies. Each label occurs exactly once per row mapping,
# so a dict lookup gives the same 0/1 matrix in O(U + L).
go_id_to_row = {go_id: i for i, go_id in enumerate(unique_go_ids)}

matrix_adaptators = []
for n, mlb_array in enumerate(MLB_ARRAYS):
    # prob_matrix_adaptator[i, j] == 1.0 iff unique_go_ids[i] == mlb_array[j]
    prob_matrix_adaptator = torch.zeros(len(unique_go_ids), len(mlb_array))
    for j, go_id in enumerate(mlb_array):
        row = go_id_to_row.get(go_id)
        if row is not None:  # every label is present by construction of unique_go_ids
            prob_matrix_adaptator[row, j] = 1.0
    print(n, " " , prob_matrix_adaptator.shape)
    matrix_adaptators.append(prob_matrix_adaptator.to(device))
102
+
103
############################################ UNIPROTKB PROT5 EMBEDDINGS ############################################

# ProtT5 encoder used to turn amino-acid sequences into per-residue
# embeddings; sequences longer than the max length are truncated on the right.
from transformers import T5Tokenizer, T5EncoderModel
model_type = "Rostlab/prot_t5_xl_uniref50"
tokenizer = T5Tokenizer.from_pretrained(model_type, do_lower_case=False, truncation_side = "right") #do not put to lower case, prot T5 needs upper case letters
protT5 = T5EncoderModel.from_pretrained(model_type, trust_remote_code=True).to(device)
max_sequence_len = 512 #prot_t5_xl pretraining done on max 512 seq len
batch_size = 64
# Ids of special tokens (pad/eos/...), masked out of the embeddings later.
SPECIAL_IDS = set(tokenizer.all_special_ids)
# Freeze params, inference only
# NOTE(review): max_sequence_len and batch_size appear unused here (MAX_SEQ_LEN
# and test_batch_size are used instead) — confirm before removing.
protT5.eval()
for param in protT5.parameters():
    param.requires_grad = False
116
+
117
def preprocess_sequence(seq: str) -> str:
    """Format a raw amino-acid string for the ProtT5 tokenizer.

    Inserts a space between residues (ProtT5 expects space-separated tokens)
    and maps the rare/ambiguous amino acids U, Z, O and B to the unknown
    residue X.
    """
    spaced = " ".join(seq)
    return re.sub(r"[UZOB]", "X", spaced)
121
+
122
+
123
def fetch_uniprot_sequence(uniprot_id: str) -> str:
    """Download the FASTA entry for *uniprot_id* and return the bare sequence.

    Strips the FASTA header line(s) and joins the remaining lines.
    Raises ValueError when UniProt does not answer with HTTP 200.
    """
    response = requests.get(
        f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta", timeout=10
    )
    if response.status_code != 200:
        raise ValueError(f"UniProt ID '{uniprot_id}' not found")

    sequence_lines = [
        line for line in response.text.splitlines() if not line.startswith(">")
    ]
    return "".join(sequence_lines)
131
+
132
+
133
def generate_embedding_from_uniprot(uniprot_id: str):
    """Fetch a UniProt sequence and embed it with the frozen ProtT5 encoder.

    Returns a tuple (embeddings, mask): a (1, L, D) embedding tensor with
    special-token positions zeroed out, and the matching (1, L) residue mask.
    """
    sequence = preprocess_sequence(fetch_uniprot_sequence(uniprot_id))

    encoded = tokenizer(
        sequence,
        return_tensors="pt",
        truncation=True,
        add_special_tokens=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden = protT5(**encoded).last_hidden_state  # (1, L, D)

    token_ids = encoded["input_ids"]
    residue_mask = encoded["attention_mask"].clone()  # (1, L)

    # Remove special tokens (pad/eos/...) from the mask so downstream pooling
    # only sees real residues.
    for special_id in SPECIAL_IDS:
        residue_mask[token_ids == special_id] = 0

    # Zero the embedding rows that the mask excludes.
    masked = hidden * residue_mask.unsqueeze(-1).float()  # (1, L, D)
    return masked, residue_mask
162
+
163
+
164
+ ############################################ PREDICTION CODE ############################################
165
+
166
def ensemble_predict(embedding, mask, topk=20):
    """Average the fold models' predictions and return the top-k GO terms.

    Each model predicts over its own label space; its output is projected
    into the shared `unique_go_ids` space via the precomputed 0/1 adaptator
    matrix before averaging.

    Parameters: `embedding`/`mask` come from generate_embedding_from_uniprot
    (shapes (1, L, D) and (1, L)); `topk` is the number of terms returned.
    Returns a pd.DataFrame with columns GO_ID and Score (3 decimals).

    NOTE(review): the inline shape comments look inconsistent — `adaptor @
    preds` needs preds shaped (num_labels, 1), not the commented (1, num_go).
    Verify against the fold model's actual output shape.
    """
    scores = torch.zeros(1, len(unique_go_ids), device=device)
    counts = torch.zeros_like(scores)

    for model, adaptor in zip(MODELS, matrix_adaptators):
        preds = model(embedding, mask) # (num_go, 1)
        preds = preds.transpose(0, 1) # (1, num_go)
        adapted = adaptor @ preds # (unique_go, 1)
        adapted = adapted.T # (1, unique_go)

        scores += adapted
        # Count only models that emitted a non-zero score for a term, so the
        # average is over contributing models, not the whole ensemble.
        counts += (torch.abs(adapted) > 1e-9).float()

    # clamp(min=1) avoids division by zero for terms no model scored
    scores /= torch.clamp(counts, min=1)
    scores = scores.squeeze(0)

    # Zero out anything below the global confidence threshold.
    mask = scores > THRESH
    scores = scores * mask.float()

    idx = torch.argsort(scores, descending=True)[:topk]

    return pd.DataFrame({
        "GO_ID": unique_go_ids[idx.cpu().numpy()],
        "Score": scores[idx].round(decimals=3).detach().cpu().numpy()
    })
191
+
192
+
193
def predict(uniprot_id, topk):
    """End-to-end inference: UniProt accession -> top-k GO-term DataFrame."""
    protein_embedding, residue_mask = generate_embedding_from_uniprot(uniprot_id)
    return ensemble_predict(protein_embedding, residue_mask, topk)
196
+
197
+
198
+ ############################################ UNIPROTKB AND GRADIO UTILS ############################################
199
+
200
def fetch_human_uniprot_examples(n=500):
    """Return up to *n* reviewed human (taxon 9606) UniProt accessions.

    Queries the UniProtKB search endpoint for Swiss-Prot entries and extracts
    the primary accession of each hit. Raises requests.HTTPError on failure.
    """
    response = requests.get(
        "https://rest.uniprot.org/uniprotkb/search",
        params={
            "query": "organism_id:9606 AND reviewed:true",
            "fields": "accession",
            "format": "json",
            "size": n,
        },
        timeout=10,
    )
    response.raise_for_status()

    return [entry["primaryAccession"] for entry in response.json()["results"]]
214
+
215
+
216
def fetch_quickgo_annotations(uniprot_id, aspects=None):
    """Collect the set of GO ids annotated to *uniprot_id* in QuickGO.

    Pages through the QuickGO annotation search for each ontology aspect.
    Network failures are logged and that aspect is abandoned (best effort),
    so the returned set may be incomplete rather than raising.
    """
    if aspects is None:
        aspects = ["biological_process", "molecular_function", "cellular_component"]

    go_ids = set()
    pageSize = 200 # max allowed by API

    for aspect in aspects:
        page = 1
        while True:
            url = (
                f"https://www.ebi.ac.uk/QuickGO/services/annotation/search?"
                f"geneProductId=UniProtKB:{uniprot_id}&limit={pageSize}&page={page}&aspect={aspect}"
            )
            try:
                response = requests.get(url, headers={"Accept": "application/json"}, timeout=10)
                response.raise_for_status()
            except Exception as e:
                # Best effort: keep whatever was already collected.
                print(f"Warning: failed to fetch {aspect} annotations: {e}")
                break

            data = response.json()
            results = data.get("results", [])
            if not results:
                # Empty page => no more annotations for this aspect.
                break

            for item in results:
                go_id = item.get("goId")
                if go_id:
                    go_ids.add(go_id)

            page += 1

    return go_ids
250
+
251
+
252
def color_topk_predictions(pred_df, true_go_ids):
    """Attach a 'Color' column to *pred_df*: green background when the GO id
    appears in *true_go_ids*, red otherwise.

    Mutates *pred_df* in place and returns it.
    """
    hit_style = "background-color: #d4edda"
    miss_style = "background-color: #f8d7da"
    pred_df["Color"] = [
        hit_style if go_id in true_go_ids else miss_style
        for go_id in pred_df["GO_ID"]
    ]
    return pred_df
261
+
262
def predictions_to_html(pred_df):
    """Render the prediction DataFrame as a centered HTML table.

    Each row's 'Color' value becomes the <tr> inline style, so hits/misses
    stay visually distinguished in the Gradio HTML component.
    """
    parts = [
        "<div style='text-align: center;'>",
        "<table border='1' style='border-collapse: collapse; margin: 0 auto;'>",
        "<tr><th>GO_ID</th><th>Score</th></tr>",
    ]

    for _, row in pred_df.iterrows():
        parts.append(f"<tr style='{row['Color']}'>")
        parts.append(f"<td>{row['GO_ID']}</td>")
        parts.append(f"<td>{row['Score']}</td>")
        parts.append("</tr>")

    parts.append("</table></div>")
    return "".join(parts)
276
+
277
+ ############################################ GRADIO APP ############################################
278
+
279
# Citation / attribution markdown rendered at the bottom of the Gradio page.
# Raw string so the BibTeX \url macros survive verbatim.
markdown_information= r"""
## Trained on CAFA6 Protein Function Prediction Dataset
```
@misc{cafa-6-protein-function-prediction,
    author = {Iddo Friedberg and Predrag Radivojac and Paul D Thomas and An Phan and M. Clara De Paolis Kaluza and Damiano Piovesan and Parnal Joshi and Chris Mungall and Martyna Plomecka and Walter Reade and María Cruz},
    title = {CAFA 6 Protein Function Prediction},
    year = {2025},
    howpublished = {\url{https://kaggle.com/competitions/cafa-6-protein-function-prediction}},
    note = {Kaggle}
}
```

## SPROF-GO

```
@article{10.1093/bib/bbad117,
    author = {Yuan, Qianmu and Xie, Junjie and Xie, Jiancong and Zhao, Huiying and Yang, Yuedong},
    title = "{Fast and accurate protein function prediction from sequence through pretrained language model and homology-based label diffusion}",
    journal = {Briefings in Bioinformatics},
    year = {2023},
    month = {03},
    issn = {1477-4054},
    doi = {10.1093/bib/bbad117},
    url = {https://doi.org/10.1093/bib/bbad117}
}
```
"""
306
+
307
def predict_with_uniprot_highlight_html(uniprot_id, topk):
    """Predict GO terms for a UniProt id and render them as an HTML table,
    coloring each prediction by agreement with QuickGO annotations."""
    protein_embedding, residue_mask = generate_embedding_from_uniprot(uniprot_id)
    predictions = ensemble_predict(protein_embedding, residue_mask, topk)
    annotated = color_topk_predictions(
        predictions, fetch_quickgo_annotations(uniprot_id)
    )
    return predictions_to_html(annotated)
313
+
314
# Cache a pool of example accessions once at startup. ROBUSTNESS FIX: the
# original call was unguarded, so any UniProt outage at boot crashed the whole
# app; fall back to a small static list of well-known human proteins instead.
try:
    HUMAN_EXAMPLES = fetch_human_uniprot_examples()
except Exception as e:
    print(f"Warning: could not fetch UniProt examples ({e}); using fallback list")
    HUMAN_EXAMPLES = ["O75594", "P04637", "P68871", "P01308", "P38398"]

def format_human_examples_md(examples):
    """Render accessions as one markdown line of UniProt links separated by '/'."""
    md = ""
    md += " ".join(
        f"[{acc}](https://www.uniprot.org/uniprotkb/{acc}) /"
        for acc in examples
    )
    return md

def fetch_human_examples_md():
    """Return markdown for a fresh random sample of up to 40 cached accessions."""
    examples = HUMAN_EXAMPLES.copy()
    random.shuffle(examples)
    return format_human_examples_md(examples[:40])
328
+
329
+
330
+
331
# Gradio UI: an input column (UniProt id + top-k slider), an examples column,
# and an HTML output table colored by agreement with QuickGO annotations.
# Component creation order inside the context managers defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🧬 [SPROF-GO](https://github.com/biomed-AI/SPROF-GO) Ensemble Trained on [CAFA6](https://www.kaggle.com/competitions/cafa-6-protein-function-prediction/overview)")

    # ===================== Inputs =====================
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("## Inference")
            gr.Markdown("⚠️ No label diffusion for fast inference.")
            uniprot_input = gr.Textbox(label="UniProtKB Protein ID", value="O75594")
            topk_slider = gr.Slider(5, 50, value=10, step=5, label="Top-K GO terms")
            run_btn = gr.Button("Predict")

        with gr.Column(scale=1):
            gr.Markdown("## Human Protein Examples")
            # Seed the example list once; the button re-samples the cached pool.
            human_examples_md_comp = gr.Markdown(format_human_examples_md(HUMAN_EXAMPLES[:50]))
            example_btn = gr.Button("🔄 Fetch more examples")
            example_btn.click(fetch_human_examples_md, outputs=human_examples_md_comp)

    # ===================== Output =====================
    gr.HTML("<hr style='margin:20px 0;'>") # horizontal divider

    gr.Markdown("### Prediction Table")
    html_output = gr.HTML(label="Predicted GO terms")

    gr.Markdown(
        "Top-K predictions colored green if predicted GO term is in "
        "[QuickGO](https://www.ebi.ac.uk/QuickGO/annotations) annotations, "
        "red otherwise"
    )

    gr.HTML("<hr style='margin:20px 0;'>") # horizontal divider

    gr.Markdown(markdown_information)

    # Button click runs a full prediction for the typed UniProt id.
    run_btn.click(predict_with_uniprot_highlight_html,
              inputs=[uniprot_input, topk_slider],
              outputs=[html_output])

    # Also run once on page load so the default protein shows a result.
    demo.load(
        predict_with_uniprot_highlight_html,
        inputs=[uniprot_input, topk_slider],
        outputs=[html_output]
    )

# Bind address/port match the EXPOSE and ENV settings in the Dockerfile.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)