wangjin2000 commited on
Commit
0c3de7a
·
verified ·
1 Parent(s): fa9d40f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import gradio as gr
3
 
4
  import os
5
- from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback
6
  import torch
7
  from torch.utils.data import DataLoader, Dataset, RandomSampler
8
  from torch.optim import AdamW
@@ -141,8 +141,27 @@ def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
141
  # Loss is already averaged by the model
142
  avg_loss = loss.item()
143
  pseudo_perplexity = np.exp(avg_loss)
 
144
  return pseudo_perplexity
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 4):
147
 
148
  peptide_length = int(peptide_length)
@@ -171,8 +190,13 @@ def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_
171
  predicted_indices = Categorical(probabilities).sample()
172
  predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
173
  generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')
 
174
  # Compute PPL for the generated binder
175
  ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
 
 
 
 
176
 
177
  # Add the generated binder and its PPL to the results list
178
  binders_with_ppl.append([generated_binder, ppl_value])
 
2
  import gradio as gr
3
 
4
  import os
5
+ from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback, EsmForProteinFolding
6
  import torch
7
  from torch.utils.data import DataLoader, Dataset, RandomSampler
8
  from torch.optim import AdamW
 
141
  # Loss is already averaged by the model
142
  avg_loss = loss.item()
143
  pseudo_perplexity = np.exp(avg_loss)
144
+
145
  return pseudo_perplexity
146
 
147
# Cache for the (large) ESMFold model and tokenizer so they are loaded
# only once per process instead of on every call — compute_avg_plddt is
# invoked once per generated binder, and reloading "facebook/esmfold_v1"
# each time is extremely expensive.
_ESMFOLD_CACHE = {}


def _get_esmfold():
    """Load and memoize the ESMFold model and its tokenizer.

    Returns:
        tuple: (model, tokenizer) for "facebook/esmfold_v1". The model is
        switched to eval() mode since it is only used for inference here.
    """
    if "model" not in _ESMFOLD_CACHE:
        model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
        model.eval()  # inference only; disable dropout etc.
        _ESMFOLD_CACHE["model"] = model
        _ESMFOLD_CACHE["tokenizer"] = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    return _ESMFOLD_CACHE["model"], _ESMFOLD_CACHE["tokenizer"]


def compute_avg_plddt(protein_seq, binder_seq):
    """Return the mean pLDDT confidence of an ESMFold prediction for the
    concatenated protein + binder sequence.

    Args:
        protein_seq (str): target protein sequence.
        binder_seq (str): candidate binder peptide sequence.

    Returns:
        torch.Tensor: scalar tensor — the mean of the per-residue pLDDT
        scores reported by ESMFold (higher = more confident structure).
    """
    # Always the ESMFold model (cached after the first call).
    model, tokenizer = _get_esmfold()

    # NOTE(review): the two sequences are folded as one naive concatenation,
    # with no linker/chain-break between protein and binder — confirm this
    # is the intended modeling choice.
    sequence = protein_seq + binder_seq
    inputs = tokenizer(sequence, return_tensors='pt', add_special_tokens=False)

    # Get model predictions without building autograd graphs.
    with torch.no_grad():
        outputs = model(**inputs)

    # Per-residue predicted LDDT confidence scores from ESMFold.
    avg_plddt = outputs.plddt.mean()

    return avg_plddt
164
+
165
  def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 4):
166
 
167
  peptide_length = int(peptide_length)
 
190
  predicted_indices = Categorical(probabilities).sample()
191
  predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
192
  generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')
193
+
194
  # Compute PPL for the generated binder
195
  ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
196
+
197
+ # Get PLDDT from ESMFold model
198
+ plddt_value = compute_avg_plddt(protein_seq, generated_binder)
199
+
200
 
201
  # Add the generated binder and its PPL to the results list
202
  binders_with_ppl.append([generated_binder, ppl_value])