wangjin2000 committed on
Commit
cdb0d81
·
verified ·
1 Parent(s): 519edc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -137
app.py CHANGED
@@ -3,10 +3,13 @@ import gradio as gr
3
 
4
  import os
5
  from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, TrainerCallback
6
- from torch.utils.data import DataLoader, Dataset, RandomSampler
7
- import pandas as pd
8
  import torch
 
9
  from torch.optim import AdamW
 
 
 
 
10
  #import wandb
11
  import numpy as np
12
  from datetime import datetime
@@ -114,17 +117,25 @@ def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
114
  sequence = protein_seq + binder_seq
115
  original_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
116
  length_of_binder = len(binder_seq)
117
-
 
 
 
118
  # Prepare a batch with each row having one masked token from the binder sequence
119
  masked_inputs = original_input.repeat(length_of_binder, 1)
120
  positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)
121
-
 
 
122
  masked_inputs[torch.arange(length_of_binder), positions_to_mask] = tokenizer.mask_token_id
123
-
 
124
  # Prepare labels for the masked tokens
125
  labels = torch.full_like(masked_inputs, -100)
 
126
  labels[torch.arange(length_of_binder), positions_to_mask] = original_input[0, positions_to_mask]
127
-
 
128
  # Get model predictions and calculate loss
129
  with torch.no_grad():
130
  outputs = model(masked_inputs, labels=labels)
@@ -135,37 +146,7 @@ def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
135
  pseudo_perplexity = np.exp(avg_loss)
136
  return pseudo_perplexity
137
 
138
- # Alternative implementation: Use Loop
139
- def compute_pseudo_perplexity2(model, tokenizer, protein_seq, binder_seq):
140
- sequence = protein_seq + binder_seq
141
- tensor_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
142
- total_loss = 0
143
-
144
- # Loop through each token in the binder sequence
145
- for i in range(-len(binder_seq)-1, -1):
146
- # Create a copy of the original tensor
147
- masked_input = tensor_input.clone()
148
-
149
- # Mask one token at a time
150
- masked_input[0, i] = tokenizer.mask_token_id
151
- # Create labels
152
- labels = torch.full(tensor_input.shape, -100).to(model.device)
153
- labels[0, i] = tensor_input[0, i]
154
-
155
- # Get model prediction and loss
156
- with torch.no_grad():
157
- outputs = model(masked_input, labels=labels)
158
- total_loss += outputs.loss.item()
159
-
160
- # Calculate the average loss
161
- avg_loss = total_loss / len(binder_seq)
162
-
163
- # Calculate pseudo perplexity
164
- pseudo_perplexity = np.exp(avg_loss)
165
- return pseudo_perplexity
166
-
167
-
168
- def generate_peptide_for_single_sequence(protein_seq, peptide_length = 15, top_k = 3, num_binders = 4):
169
 
170
  peptide_length = int(peptide_length)
171
  top_k = int(top_k)
@@ -212,9 +193,36 @@ def generate_peptide(input_seqs, peptide_length=15, top_k=3, num_binders=4):
212
  for binder, ppl in binders:
213
  results.append([seq, binder, ppl])
214
  return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  def suggest(option):
216
- if option == "Plastic degradation protein":
217
- suggestion = "MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFMDNDTRYSTFACENPNSTRVSDFRTANCSLEDPAANKARKEAELAAATAEQ"
218
  elif option == "Default protein":
219
  #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
220
  suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
@@ -227,97 +235,6 @@ def suggest(option):
227
  else:
228
  suggestion = ""
229
  return suggestion
230
-
231
- # Helper Functions and Data Preparation
232
- def truncate_labels(labels, max_length):
233
- """Truncate labels to the specified max_length."""
234
- return [label[:max_length] for label in labels]
235
-
236
- def compute_metrics(p):
237
- """Compute metrics for evaluation."""
238
- predictions, labels = p
239
- predictions = np.argmax(predictions, axis=2)
240
-
241
- # Remove padding (-100 labels)
242
- predictions = predictions[labels != -100].flatten()
243
- labels = labels[labels != -100].flatten()
244
-
245
- # Compute accuracy
246
- accuracy = accuracy_score(labels, predictions)
247
-
248
- # Compute precision, recall, F1 score, and AUC
249
- precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
250
- auc = roc_auc_score(labels, predictions)
251
-
252
- # Compute MCC
253
- mcc = matthews_corrcoef(labels, predictions)
254
-
255
- return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
256
-
257
- def compute_loss(model, inputs):
258
- """Custom compute_loss function."""
259
- logits = model(**inputs).logits
260
- labels = inputs["labels"]
261
- loss_fct = nn.CrossEntropyLoss(weight=class_weights)
262
- active_loss = inputs["attention_mask"].view(-1) == 1
263
- active_logits = logits.view(-1, model.config.num_labels)
264
- active_labels = torch.where(
265
- active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
266
- )
267
- loss = loss_fct(active_logits, active_labels)
268
- return loss
269
-
270
- # Predict binding site with finetuned PEFT model
271
- def predict_bind(base_model_path,PEFT_model_path,input_seq):
272
- # Load the model
273
- base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
274
- loaded_model = PeftModel.from_pretrained(base_model, PEFT_model_path)
275
-
276
- # Ensure the model is in evaluation mode
277
- loaded_model.eval()
278
-
279
- # Tokenization
280
- tokenizer = AutoTokenizer.from_pretrained(base_model_path)
281
-
282
- # Tokenize the sequence
283
- inputs = tokenizer(input_seq, return_tensors="pt", truncation=True, max_length=1024, padding='max_length')
284
-
285
- # Run the model
286
- with torch.no_grad():
287
- logits = loaded_model(**inputs).logits
288
-
289
- # Get predictions
290
- tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
291
- predictions = torch.argmax(logits, dim=2)
292
-
293
- binding_site=[]
294
- pos = 0
295
- # Print the predicted labels for each token
296
- for token, prediction in zip(tokens, predictions[0].numpy()):
297
- if token not in ['<pad>', '<cls>', '<eos>']:
298
- pos += 1
299
- print((pos, token, id2label[prediction]))
300
- if prediction == 1:
301
- print((pos, token, id2label[prediction]))
302
- binding_site.append([pos, token, id2label[prediction]])
303
-
304
- return binding_site
305
-
306
- MODEL_OPTIONS = [
307
- "facebook/esm2_t6_8M_UR50D",
308
- "facebook/esm2_t12_35M_UR50D",
309
- "facebook/esm2_t33_650M_UR50D",
310
- ] # models users can choose from
311
-
312
- PEFT_MODEL_OPTIONS = [
313
- "wangjin2000/esm2_t6_8M-lora-binding-sites_2024-07-02_09-26-54",
314
- "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3",
315
- ] # finetuned models
316
-
317
- '''
318
- # debug result
319
- dubug_result = saved_path #predictions #class_weights
320
- '''
321
 
322
  demo = gr.Blocks(title="ESM2 for Protein-Protein Interaction (ESM2PPI)")
323
 
@@ -345,8 +262,8 @@ with demo:
345
  with gr.Column(scale=5, variant="compact"):
346
  name = gr.Dropdown(
347
  label="Choose a Sample Protein",
348
- value="Default protein",
349
- choices=["Default protein", "Antifreeze protein", "Plastic degradation protein", "AI Generated protein", "7-bladed propeller fold", "custom"]
350
  )
351
  gr.Markdown(
352
  "## Predict binding site and Plot structure for selected protein sequence:"
@@ -356,8 +273,8 @@ with demo:
356
  input_seq = gr.Textbox(
357
  lines=1,
358
  max_lines=12,
359
- label="Protein sequency to be predicted:",
360
- value="MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT",
361
  placeholder="Paste your protein sequence here...",
362
  interactive = True,
363
  )
@@ -371,7 +288,7 @@ with demo:
371
  )
372
  with gr.Column(variant="compact", scale = 2):
373
  predict_btn = gr.Button(
374
- value="Predict binding site",
375
  interactive=True,
376
  variant="primary",
377
  )
@@ -402,9 +319,9 @@ with demo:
402
  # select protein sample
403
  name.change(fn=suggest, inputs=name, outputs=input_seq)
404
 
405
- # "Predict binding site" actions
406
  predict_btn.click(
407
- fn = predict_bind,
408
  inputs=[base_model_name,PEFT_model_name,input_seq],
409
  outputs = [output_text],
410
  )
 
3
 
4
  import os
5
  from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, TrainerCallback
 
 
6
  import torch
7
+ from torch.utils.data import DataLoader, Dataset, RandomSampler
8
  from torch.optim import AdamW
9
+ from torch.distributions import Categorical
10
+
11
+ import pandas as pd
12
+
13
  #import wandb
14
  import numpy as np
15
  from datetime import datetime
 
117
  sequence = protein_seq + binder_seq
118
  original_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
119
  length_of_binder = len(binder_seq)
120
+ print("length of original_input:",len(original_input))
121
+ print("length of binder:",length_of_binder)
122
+ print("original_input:",original_input)
123
+
124
  # Prepare a batch with each row having one masked token from the binder sequence
125
  masked_inputs = original_input.repeat(length_of_binder, 1)
126
  positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)
127
+ print("masked_inputs:",masked_inputs)
128
+ print("positions_to_mask:",positions_to_mask)
129
+
130
  masked_inputs[torch.arange(length_of_binder), positions_to_mask] = tokenizer.mask_token_id
131
+ print("masked_inputs tokens:",masked_inputs[torch.arange(length_of_binder), positions_to_mask])
132
+
133
  # Prepare labels for the masked tokens
134
  labels = torch.full_like(masked_inputs, -100)
135
+ print("labels:",labels)
136
  labels[torch.arange(length_of_binder), positions_to_mask] = original_input[0, positions_to_mask]
137
+ print("labels 137:",labels)
138
+
139
  # Get model predictions and calculate loss
140
  with torch.no_grad():
141
  outputs = model(masked_inputs, labels=labels)
 
146
  pseudo_perplexity = np.exp(avg_loss)
147
  return pseudo_perplexity
148
 
149
+ def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 4):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  peptide_length = int(peptide_length)
152
  top_k = int(top_k)
 
193
  for binder, ppl in binders:
194
  results.append([seq, binder, ppl])
195
  return pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])
196
+
197
# Predict peptide binder with finetuned model
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, top_k=3, num_binders=4):
    """Generate candidate peptide binders for one or more protein sequences.

    Args:
        base_model_path: hub path / local dir of the base model; only its
            tokenizer is used here.
        finetuned_model_path: hub path / local dir of the finetuned
            masked-LM checkpoint used for generation.
        input_seqs: a single protein sequence (str) or a list of sequences.
        peptide_length: length of each generated binder peptide.
        top_k: sampling breadth forwarded to the per-sequence generator.
        num_binders: number of binder candidates per input sequence.

    Returns:
        pandas.DataFrame with columns ['Binder', 'Pseudo Perplexity'] when
        given a single sequence, or
        ['Input Sequence', 'Binder', 'Pseudo Perplexity'] when given a list.

    Raises:
        TypeError: if input_seqs is neither a str nor a list.
    """
    # Load the finetuned masked language model
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)

    # Ensure the model is in evaluation mode (disables dropout etc.)
    loaded_model.eval()

    # Tokenizer always comes from the base model path
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # NOTE(fix): the previous version assigned `resuls_df`, printed
    # `results_df`, and returned `result_df` — two of those names were
    # undefined, so the function always raised NameError. One consistent
    # name is used throughout now.
    if isinstance(input_seqs, str):  # Single sequence
        binders = generate_peptide_for_single_sequence(
            loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders
        )
        results_df = pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
    elif isinstance(input_seqs, list):  # List of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(
                loaded_model, tokenizer, seq, peptide_length, top_k, num_binders
            )
            for binder, ppl in binders:
                results.append([seq, binder, ppl])
        results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'Pseudo Perplexity'])
    else:
        raise TypeError("input_seqs must be a str or a list of str")

    return results_df
222
+
223
  def suggest(option):
224
+ if option == "Protein:P63279":
225
+ suggestion = "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS"
226
  elif option == "Default protein":
227
  #suggestion = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
228
  suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
 
235
  else:
236
  suggestion = ""
237
  return suggestion
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  demo = gr.Blocks(title="ESM2 for Protein-Protein Interaction (ESM2PPI)")
240
 
 
262
  with gr.Column(scale=5, variant="compact"):
263
  name = gr.Dropdown(
264
  label="Choose a Sample Protein",
265
+ value="Protein:P63279",
266
+ choices=["Default protein", "Antifreeze protein", "Protein:P63279", "AI Generated protein", "7-bladed propeller fold", "custom"]
267
  )
268
  gr.Markdown(
269
  "## Predict binding site and Plot structure for selected protein sequence:"
 
273
  input_seq = gr.Textbox(
274
  lines=1,
275
  max_lines=12,
276
+ label="Protein:P63279 to be predicted:",
277
+ value="MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
278
  placeholder="Paste your protein sequence here...",
279
  interactive = True,
280
  )
 
288
  )
289
  with gr.Column(variant="compact", scale = 2):
290
  predict_btn = gr.Button(
291
+ value="Predict peptide sequence",
292
  interactive=True,
293
  variant="primary",
294
  )
 
319
  # select protein sample
320
  name.change(fn=suggest, inputs=name, outputs=input_seq)
321
 
322
+ # "Predict peptide sequence" actions
323
  predict_btn.click(
324
+ fn = predict_peptide,
325
  inputs=[base_model_name,PEFT_model_name,input_seq],
326
  outputs = [output_text],
327
  )