wangjin2000 committed on
Commit
db7f9c6
·
verified ·
1 Parent(s): 18227f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -17
app.py CHANGED
@@ -36,13 +36,14 @@ PEFT_MODEL_OPTIONS = [
36
 
37
  #build datasets
38
  class ProteinDataset(Dataset):
39
- def __init__(self, file, tokenizer):
40
  data = pd.read_csv(file)
41
  self.tokenizer = tokenizer
42
  self.proteins = data["Receptor Sequence"].tolist()
43
  self.peptides = data["Binder"].tolist()
44
  #self.proteins = data["P_Sequence"].tolist() #header defined by Lin Qiao
45
  #self.peptides = data["p_Sequence"].tolist()
 
46
 
47
  def __len__(self):
48
  return len(self.proteins)
@@ -55,14 +56,14 @@ class ProteinDataset(Dataset):
55
  complex_seq = protein_seq + masked_peptide
56
 
57
  # Tokenize and pad the complex sequence
58
- complex_input = self.tokenizer(complex_seq, return_tensors="pt", padding="max_length", max_length = 552, truncation=True)
59
 
60
  input_ids = complex_input["input_ids"].squeeze()
61
  attention_mask = complex_input["attention_mask"].squeeze()
62
 
63
  # Create labels (tokens for ground truth AAs)
64
  label_seq = protein_seq + peptide_seq
65
- labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = 552, truncation=True)["input_ids"].squeeze()
66
 
67
  # Set non-masked positions in the labels tensor to -100
68
  labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
@@ -70,7 +71,7 @@ class ProteinDataset(Dataset):
70
  return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
71
 
72
  # fine-tuning function
73
- def finetune(base_model_path): #, train_dataset, test_dataset):
74
 
75
  #load base model
76
  base_model = EsmForMaskedLM.from_pretrained(base_model_path)
@@ -78,9 +79,8 @@ def finetune(base_model_path): #, train_dataset, test_dataset):
78
  # Tokenization
79
  tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
80
 
81
- train_dataset = ProteinDataset("./datasets/pepnn_train.csv", tokenizer)
82
- test_dataset = ProteinDataset("./datasets/pepnn_test.csv", tokenizer)
83
- print("line 84 testset:",test_dataset)
84
 
85
  model_name_base = base_model_path.split("/")[1]
86
  timestamp = datetime.now().strftime('%Y-%m-%d_%H')
@@ -179,7 +179,7 @@ def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_
179
  return binders_with_ppl
180
 
181
  # Predict peptide binder with finetuned model
182
- def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, top_k=3, num_binders=4):
183
  # Load the model
184
  loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
185
 
@@ -211,7 +211,7 @@ def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_l
211
 
212
  return results_df, PPC
213
 
214
- def predict_peptide_from_file(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, top_k=3, num_binders=4):
215
  # Load the model
216
  loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
217
 
@@ -221,6 +221,12 @@ def predict_peptide_from_file(base_model_path, finetuned_model_path, input_seqs,
221
  # Tokenization
222
  tokenizer = AutoTokenizer.from_pretrained(base_model_path)
223
 
 
 
 
 
 
 
224
  if isinstance(input_seqs, str): # Single sequence
225
  binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders)
226
  results_df = pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
@@ -296,7 +302,7 @@ with demo:
296
  file_count="single",
297
  file_types=[".tsv", ".csv"],
298
  type="filepath",
299
- height=20,
300
  )
301
  gr.Markdown(
302
  "## Predict peptide sequence:"
@@ -321,11 +327,17 @@ with demo:
321
  )
322
  with gr.Column(variant="compact", scale = 2):
323
  predict_btn = gr.Button(
324
- value="Predict peptide sequence",
 
 
 
 
 
 
 
325
  interactive=True,
326
  variant="primary",
327
  )
328
- plot_struc_btn = gr.Button(value = "Plot ESMFold Predicted Structure ", variant="primary")
329
  with gr.Row():
330
  with gr.Column(variant="compact", scale = 5):
331
  output_text = gr.Textbox(
@@ -340,11 +352,6 @@ with demo:
340
  interactive=True,
341
  variant="primary",
342
  )
343
- predict_file_btn = gr.Button(
344
- value="Predict peptide from a local file",
345
- interactive=True,
346
- variant="primary",
347
- )
348
  with gr.Row():
349
  output_viewer = gr.HTML()
350
  output_file = gr.File(
 
36
 
37
  #build datasets
38
  class ProteinDataset(Dataset):
39
+ def __init__(self, file, tokenizer, peptide_length):
40
  data = pd.read_csv(file)
41
  self.tokenizer = tokenizer
42
  self.proteins = data["Receptor Sequence"].tolist()
43
  self.peptides = data["Binder"].tolist()
44
  #self.proteins = data["P_Sequence"].tolist() #header defined by Lin Qiao
45
  #self.peptides = data["p_Sequence"].tolist()
46
+ self.max_length_pm = 500 + 2 + peptide_length # assume the max length of a protein is 500
47
 
48
  def __len__(self):
49
  return len(self.proteins)
 
56
  complex_seq = protein_seq + masked_peptide
57
 
58
  # Tokenize and pad the complex sequence
59
+ complex_input = self.tokenizer(complex_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)
60
 
61
  input_ids = complex_input["input_ids"].squeeze()
62
  attention_mask = complex_input["attention_mask"].squeeze()
63
 
64
  # Create labels (tokens for ground truth AAs)
65
  label_seq = protein_seq + peptide_seq
66
+ labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].squeeze()
67
 
68
  # Set non-masked positions in the labels tensor to -100
69
  labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
 
71
  return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
72
 
73
  # fine-tuning function
74
+ def finetune(base_model_path, peptide_length): #, train_dataset, test_dataset):
75
 
76
  #load base model
77
  base_model = EsmForMaskedLM.from_pretrained(base_model_path)
 
79
  # Tokenization
80
  tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
81
 
82
+ train_dataset = ProteinDataset("./datasets/pepnn_train.csv", tokenizer, peptide_length)
83
+ test_dataset = ProteinDataset("./datasets/pepnn_test.csv", tokenizer, peptide_length)
 
84
 
85
  model_name_base = base_model_path.split("/")[1]
86
  timestamp = datetime.now().strftime('%Y-%m-%d_%H')
 
179
  return binders_with_ppl
180
 
181
  # Predict peptide binder with finetuned model
182
+ def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, num_binders=4, top_k=3):
183
  # Load the model
184
  loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
185
 
 
211
 
212
  return results_df, PPC
213
 
214
+ def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, peptide_length=15, num_binders=4, top_k=3):
215
  # Load the model
216
  loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
217
 
 
221
  # Tokenization
222
  tokenizer = AutoTokenizer.from_pretrained(base_model_path)
223
 
224
+ eval_dataset = ProteinDataset(file_obj, tokenizer, peptide_length)
225
+ print("eval_dataset:",eval_dataset)
226
+
227
+ input_seqs = eval_dataset["input_ids"]
228
+ print("line 228 - input_seqs:",input_seqs)
229
+
230
  if isinstance(input_seqs, str): # Single sequence
231
  binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders)
232
  results_df = pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
 
302
  file_count="single",
303
  file_types=[".tsv", ".csv"],
304
  type="filepath",
305
+ height=10,
306
  )
307
  gr.Markdown(
308
  "## Predict peptide sequence:"
 
327
  )
328
  with gr.Column(variant="compact", scale = 2):
329
  predict_btn = gr.Button(
330
+ value="Predict peptide sequence from a protein sequence",
331
+ interactive=True,
332
+ variant="primary",
333
+ )
334
+ plot_struc_btn = gr.Button(value = "Plot ESMFold predicted structure ", variant="primary")
335
+
336
+ predict_file_btn = gr.Button(
337
+ value="Predict peptide from a local file",
338
  interactive=True,
339
  variant="primary",
340
  )
 
341
  with gr.Row():
342
  with gr.Column(variant="compact", scale = 5):
343
  output_text = gr.Textbox(
 
352
  interactive=True,
353
  variant="primary",
354
  )
 
 
 
 
 
355
  with gr.Row():
356
  output_viewer = gr.HTML()
357
  output_file = gr.File(