wangjin2000 committed
Commit 2334c10 · verified · 1 Parent(s): d17119b

Create app.py

Files changed (1)
1. app.py +383 -0
app.py ADDED
@@ -0,0 +1,383 @@
#ref: https://huggingface.co/blog/AmelieSchreiber/esmbind
import gradio as gr

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#import wandb
import numpy as np
import torch
import torch.nn as nn
import pickle
from datetime import datetime
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    matthews_corrcoef
)
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer
)

from datasets import Dataset
from accelerate import Accelerator
# Imports specific to the custom PEFT LoRA model
from peft import PeftModel, get_peft_model, LoraConfig, TaskType

from plot_pdb import plot_struc

def suggest(option):
    if option == "Plastic degradation protein":
        suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
    elif option == "Default protein":
        #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
        suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""
    return suggestion

# Helper Functions and Data Preparation
def truncate_labels(labels, max_length):
    """Truncate labels to the specified max_length."""
    return [label[:max_length] for label in labels]
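# A quick illustration of the truncation (values are made up):
#   >>> truncate_labels([[0, 1, 0, 1], [1, 0]], 3)
#   [[0, 1, 0], [1, 0]]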

def compute_metrics(p):
    """Compute metrics for evaluation."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove padding (-100 labels)
    predictions = predictions[labels != -100].flatten()
    labels = labels[labels != -100].flatten()

    # Compute accuracy
    accuracy = accuracy_score(labels, predictions)

    # Compute precision, recall, F1 score, and AUC
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
    auc = roc_auc_score(labels, predictions)

    # Compute MCC
    mcc = matthews_corrcoef(labels, predictions)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
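# Note: the AUC above is computed from hard argmax predictions. A smoother
# alternative (not what this app does) would rank by the class-1 probability,
# e.g. softmax(logits, axis=-1)[..., 1] restricted to non-padding positions.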

def compute_loss(model, inputs):
    """Custom compute_loss function using the global class weights."""
    logits = model(**inputs).logits
    labels = inputs["labels"]
    loss_fct = nn.CrossEntropyLoss(weight=class_weights)
    active_loss = inputs["attention_mask"].view(-1) == 1
    active_logits = logits.view(-1, model.config.num_labels)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

# Define Custom Trainer Class
# Because of the imbalance between non-binding and binding residues, we train
# with class weights, which requires a custom weighted trainer.
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = compute_loss(model, inputs)
        return (loss, outputs) if return_outputs else loss
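# For intuition: 'balanced' class weights are n_samples / (n_classes * count).
# Illustrative numbers only (not taken from the training data):
#   >>> compute_class_weight(class_weight='balanced', classes=[0, 1], y=[0] * 90 + [1] * 10)
#   array([0.5556, 5.0])   # the rare binding class is weighted ~9x higher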

# Predict binding site with finetuned PEFT model
def predict_bind(base_model_path, PEFT_model_path, input_seq):
    # Load the model
    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
    loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)

    # Ensure the model is in evaluation mode
    loaded_model.eval()

    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # Tokenize the sequence
    inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')

    # Run the model
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Get predictions
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # Convert input ids back to tokens
    predictions = torch.argmax(logits, dim=2)

    binding_site = []
    pos = 0
    # Collect (and print) the predicted binding residues, skipping special tokens
    for token, prediction in zip(tokens, predictions[0].numpy()):
        if token not in ['<pad>', '<cls>', '<eos>']:
            pos += 1
            # print((pos, token, id2label[prediction]))  # debug: log every residue
            if prediction == 1:
                print((pos, token, id2label[prediction]))
                binding_site.append([pos, token, id2label[prediction]])

    return binding_site
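# Example call (model names are the dropdown options defined below):
#   predict_bind("facebook/esm2_t6_8M_UR50D",
#                "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
#                input_seq)
# returns a list of [position, residue, "Binding site"] entries.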

# fine-tuning function
def train_function_no_sweeps(base_model_path): #, train_dataset, test_dataset):

    # Set the LoRA config
    config = {
        "lora_alpha": 1, # try 0.5, 1, 2, ..., 16
        "lora_dropout": 0.2,
        "lr": 5.701568055793089e-04,
        "lr_scheduler_type": "cosine",
        "max_grad_norm": 0.5,
        "num_train_epochs": 1, # 3, jw 20240628
        "per_device_train_batch_size": 12,
        "r": 2,
        "weight_decay": 0.2,
        # Add other hyperparameters as needed
    }

    base_model = AutoModelForTokenClassification.from_pretrained(base_model_path, num_labels=len(id2label), id2label=id2label, label2id=label2id)

    # Tokenization
    tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")

    train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
    test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)

    train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
    test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)

    # Convert the model into a PeftModel
    peft_config = LoraConfig(
        task_type=TaskType.TOKEN_CLS,
        inference_mode=False,
        r=config["r"],
        lora_alpha=config["lora_alpha"],
        target_modules=["query", "key", "value"], # also try "dense_h_to_4h" and "dense_4h_to_h"
        lora_dropout=config["lora_dropout"],
        bias="none" # or "all" or "lora_only"
    )
    base_model = get_peft_model(base_model, peft_config)
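    # Optional sanity check via the PEFT helper, which reports how small the
    # trainable LoRA parameter set is relative to the frozen backbone:
    #   base_model.print_trainable_parameters()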

    # Use the accelerator
    base_model = accelerator.prepare(base_model)
    train_dataset = accelerator.prepare(train_dataset)
    test_dataset = accelerator.prepare(test_dataset)

    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')

    # Training setup
    training_args = TrainingArguments(
        output_dir=f"{model_name_base}-lora-binding-sites_{timestamp}",
        learning_rate=config["lr"],
        lr_scheduler_type=config["lr_scheduler_type"],
        gradient_accumulation_steps=1,
        max_grad_norm=config["max_grad_norm"],
        per_device_train_batch_size=config["per_device_train_batch_size"],
        per_device_eval_batch_size=config["per_device_train_batch_size"],
        num_train_epochs=config["num_train_epochs"],
        weight_decay=config["weight_decay"],
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True, #jw 20240701 False,
        logging_dir=None,
        logging_first_step=False,
        logging_steps=200,
        save_total_limit=7,
        no_cuda=False,
        seed=8893,
        fp16=True,
        #report_to='wandb'
        report_to="none",
        hub_token=HF_TOKEN, #jw 20240701
    )

    # Initialize Trainer
    trainer = WeightedTrainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
        compute_metrics=compute_metrics,
    )

    # Train and Save Model
    trainer.train()
    save_path = training_args.output_dir
    trainer.save_model(save_path)

    return save_path

# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN set:", HF_TOKEN is not None)  # avoid printing the secret itself

MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
] # models users can choose from

PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
] # finetuned models


# Load the data from pickle files (replace with your local paths)
with open("./datasets/train_sequences_chunked_by_family.pkl", "rb") as f:
    train_sequences = pickle.load(f)

with open("./datasets/test_sequences_chunked_by_family.pkl", "rb") as f:
    test_sequences = pickle.load(f)

with open("./datasets/train_labels_chunked_by_family.pkl", "rb") as f:
    train_labels = pickle.load(f)

with open("./datasets/test_labels_chunked_by_family.pkl", "rb") as f:
    test_labels = pickle.load(f)

max_sequence_length = 1000

# Directly truncate the entire list of labels
train_labels = truncate_labels(train_labels, max_sequence_length)
test_labels = truncate_labels(test_labels, max_sequence_length)

# Compute Class Weights
classes = [0, 1]
flat_train_labels = [label for sublist in train_labels for label in sublist]
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
accelerator = Accelerator()
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)

# Define label mappings
id2label = {0: "No binding site", 1: "Binding site"}
label2id = {v: k for k, v in id2label.items()}

'''
# debug result
debug_result = saved_path #predictions #class_weights
'''

demo = gr.Blocks(title="DEMO FOR ESM2Bind")

with demo:
    gr.Markdown("# DEMO FOR ESM2Bind")
    #gr.Textbox(debug_result)

    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding PEFT finetuned model")

        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive=True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="PEFT Model Name",
                    interactive=True,
                )
            with gr.Column(scale=5, variant="compact"):
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Default protein",
                    choices=["Default protein", "Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"]
                )
        gr.Markdown(
            "## Predict binding site and plot structure for the selected protein sequence:"
        )
        with gr.Row():
            with gr.Column(variant="compact", scale=8):
                input_seq = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Protein sequence to be predicted:",
                    value="MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT",
                    placeholder="Paste your protein sequence here...",
                    interactive=True,
                )
                text_pos = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Sequence Position:",
                    placeholder=
                    "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                    interactive=False,
                )
            with gr.Column(variant="compact", scale=2):
                predict_btn = gr.Button(
                    value="Predict binding site",
                    interactive=True,
                    variant="primary",
                )
                plot_struc_btn = gr.Button(value="Plot ESMFold Predicted Structure", variant="primary")
        with gr.Row():
            with gr.Column(variant="compact", scale=5):
                output_text = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Output",
                    placeholder="Output",
                )
            with gr.Column(variant="compact", scale=5):
                finetune_button = gr.Button(
                    value="Finetune Pre-trained Model",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            output_viewer = gr.HTML()
            output_file = gr.File(
                label="Download as Text File",
                file_count="single",
                type="filepath",
                interactive=False,
            )

    # select protein sample
    name.change(fn=suggest, inputs=name, outputs=input_seq)

    # "Predict binding site" actions
    predict_btn.click(
        fn=predict_bind,
        inputs=[base_model_name, PEFT_model_name, input_seq],
        outputs=[output_text],
    )

    # "Finetune Pre-trained Model" actions
    finetune_button.click(
        fn=train_function_no_sweeps,
        inputs=[base_model_name],
        outputs=[output_text],
    )

    # plot protein structure
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])


demo.launch()
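
# When running locally rather than in a Space, a shareable public link can be
# requested with the standard Gradio flag (optional): demo.launch(share=True)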