Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -36,13 +36,14 @@ PEFT_MODEL_OPTIONS = [
|
|
| 36 |
|
| 37 |
#build datasets
|
| 38 |
class ProteinDataset(Dataset):
|
| 39 |
-
def __init__(self, file, tokenizer):
|
| 40 |
data = pd.read_csv(file)
|
| 41 |
self.tokenizer = tokenizer
|
| 42 |
self.proteins = data["Receptor Sequence"].tolist()
|
| 43 |
self.peptides = data["Binder"].tolist()
|
| 44 |
#self.proteins = data["P_Sequence"].tolist() #header defined by Lin Qiao
|
| 45 |
#self.peptides = data["p_Sequence"].tolist()
|
|
|
|
| 46 |
|
| 47 |
def __len__(self):
|
| 48 |
return len(self.proteins)
|
|
@@ -55,14 +56,14 @@ class ProteinDataset(Dataset):
|
|
| 55 |
complex_seq = protein_seq + masked_peptide
|
| 56 |
|
| 57 |
# Tokenize and pad the complex sequence
|
| 58 |
-
complex_input = self.tokenizer(complex_seq, return_tensors="pt", padding="max_length", max_length =
|
| 59 |
|
| 60 |
input_ids = complex_input["input_ids"].squeeze()
|
| 61 |
attention_mask = complex_input["attention_mask"].squeeze()
|
| 62 |
|
| 63 |
# Create labels (tokens for ground truth AAs)
|
| 64 |
label_seq = protein_seq + peptide_seq
|
| 65 |
-
labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length =
|
| 66 |
|
| 67 |
# Set non-masked positions in the labels tensor to -100
|
| 68 |
labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
|
|
@@ -70,7 +71,7 @@ class ProteinDataset(Dataset):
|
|
| 70 |
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
| 71 |
|
| 72 |
# fine-tuning function
|
| 73 |
-
def finetune(base_model_path): #, train_dataset, test_dataset):
|
| 74 |
|
| 75 |
#load base model
|
| 76 |
base_model = EsmForMaskedLM.from_pretrained(base_model_path)
|
|
@@ -78,9 +79,8 @@ def finetune(base_model_path): #, train_dataset, test_dataset):
|
|
| 78 |
# Tokenization
|
| 79 |
tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
|
| 80 |
|
| 81 |
-
train_dataset = ProteinDataset("./datasets/pepnn_train.csv", tokenizer)
|
| 82 |
-
test_dataset = ProteinDataset("./datasets/pepnn_test.csv", tokenizer)
|
| 83 |
-
print("line 84 testset:",test_dataset)
|
| 84 |
|
| 85 |
model_name_base = base_model_path.split("/")[1]
|
| 86 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H')
|
|
@@ -179,7 +179,7 @@ def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_
|
|
| 179 |
return binders_with_ppl
|
| 180 |
|
| 181 |
# Predict peptide binder with finetuned model
|
| 182 |
-
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15,
|
| 183 |
# Load the model
|
| 184 |
loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
|
| 185 |
|
|
@@ -211,7 +211,7 @@ def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_l
|
|
| 211 |
|
| 212 |
return results_df, PPC
|
| 213 |
|
| 214 |
-
def predict_peptide_from_file(base_model_path, finetuned_model_path,
|
| 215 |
# Load the model
|
| 216 |
loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
|
| 217 |
|
|
@@ -221,6 +221,12 @@ def predict_peptide_from_file(base_model_path, finetuned_model_path, input_seqs,
|
|
| 221 |
# Tokenization
|
| 222 |
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
if isinstance(input_seqs, str): # Single sequence
|
| 225 |
binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders)
|
| 226 |
results_df = pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
|
|
@@ -296,7 +302,7 @@ with demo:
|
|
| 296 |
file_count="single",
|
| 297 |
file_types=[".tsv", ".csv"],
|
| 298 |
type="filepath",
|
| 299 |
-
height=
|
| 300 |
)
|
| 301 |
gr.Markdown(
|
| 302 |
"## Predict peptide sequence:"
|
|
@@ -321,11 +327,17 @@ with demo:
|
|
| 321 |
)
|
| 322 |
with gr.Column(variant="compact", scale = 2):
|
| 323 |
predict_btn = gr.Button(
|
| 324 |
-
value="Predict peptide sequence",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
interactive=True,
|
| 326 |
variant="primary",
|
| 327 |
)
|
| 328 |
-
plot_struc_btn = gr.Button(value = "Plot ESMFold Predicted Structure ", variant="primary")
|
| 329 |
with gr.Row():
|
| 330 |
with gr.Column(variant="compact", scale = 5):
|
| 331 |
output_text = gr.Textbox(
|
|
@@ -340,11 +352,6 @@ with demo:
|
|
| 340 |
interactive=True,
|
| 341 |
variant="primary",
|
| 342 |
)
|
| 343 |
-
predict_file_btn = gr.Button(
|
| 344 |
-
value="Predict peptide from a local file",
|
| 345 |
-
interactive=True,
|
| 346 |
-
variant="primary",
|
| 347 |
-
)
|
| 348 |
with gr.Row():
|
| 349 |
output_viewer = gr.HTML()
|
| 350 |
output_file = gr.File(
|
|
|
|
| 36 |
|
| 37 |
#build datasets
|
| 38 |
class ProteinDataset(Dataset):
|
| 39 |
+
def __init__(self, file, tokenizer, peptide_length):
|
| 40 |
data = pd.read_csv(file)
|
| 41 |
self.tokenizer = tokenizer
|
| 42 |
self.proteins = data["Receptor Sequence"].tolist()
|
| 43 |
self.peptides = data["Binder"].tolist()
|
| 44 |
#self.proteins = data["P_Sequence"].tolist() #header defined by Lin Qiao
|
| 45 |
#self.peptides = data["p_Sequence"].tolist()
|
| 46 |
+
self.max_length_pm = 500 + 2 + peptide_length #assume the maz length of protein is 500
|
| 47 |
|
| 48 |
def __len__(self):
|
| 49 |
return len(self.proteins)
|
|
|
|
| 56 |
complex_seq = protein_seq + masked_peptide
|
| 57 |
|
| 58 |
# Tokenize and pad the complex sequence
|
| 59 |
+
complex_input = self.tokenizer(complex_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)
|
| 60 |
|
| 61 |
input_ids = complex_input["input_ids"].squeeze()
|
| 62 |
attention_mask = complex_input["attention_mask"].squeeze()
|
| 63 |
|
| 64 |
# Create labels (tokens for ground truth AAs)
|
| 65 |
label_seq = protein_seq + peptide_seq
|
| 66 |
+
labels = self.tokenizer(label_seq, return_tensors="pt", padding="max_length", max_length = self.max_length_pm, truncation=True)["input_ids"].squeeze()
|
| 67 |
|
| 68 |
# Set non-masked positions in the labels tensor to -100
|
| 69 |
labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
|
|
|
|
| 71 |
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
| 72 |
|
| 73 |
# fine-tuning function
|
| 74 |
+
def finetune(base_model_path, peptide_length): #, train_dataset, test_dataset):
|
| 75 |
|
| 76 |
#load base model
|
| 77 |
base_model = EsmForMaskedLM.from_pretrained(base_model_path)
|
|
|
|
| 79 |
# Tokenization
|
| 80 |
tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
|
| 81 |
|
| 82 |
+
train_dataset = ProteinDataset("./datasets/pepnn_train.csv", tokenizer, peptide_length)
|
| 83 |
+
test_dataset = ProteinDataset("./datasets/pepnn_test.csv", tokenizer, peptide_length)
|
|
|
|
| 84 |
|
| 85 |
model_name_base = base_model_path.split("/")[1]
|
| 86 |
timestamp = datetime.now().strftime('%Y-%m-%d_%H')
|
|
|
|
| 179 |
return binders_with_ppl
|
| 180 |
|
| 181 |
# Predict peptide binder with finetuned model
|
| 182 |
+
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, num_binders=4, top_k=3):
|
| 183 |
# Load the model
|
| 184 |
loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
|
| 185 |
|
|
|
|
| 211 |
|
| 212 |
return results_df, PPC
|
| 213 |
|
| 214 |
+
def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, peptide_length=15, num_binders=4, top_k=3):
|
| 215 |
# Load the model
|
| 216 |
loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
|
| 217 |
|
|
|
|
| 221 |
# Tokenization
|
| 222 |
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
|
| 223 |
|
| 224 |
+
eval_dataset = ProteinDataset(file_obj, tokenizer, peptide_length)
|
| 225 |
+
print("eval_dataset:",eval_dataset)
|
| 226 |
+
|
| 227 |
+
input_seqs = eval_dataset["input_ids"]
|
| 228 |
+
print("line 228 - input_seqs:",input_seqs)
|
| 229 |
+
|
| 230 |
if isinstance(input_seqs, str): # Single sequence
|
| 231 |
binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders)
|
| 232 |
results_df = pd.DataFrame(binders, columns=['Binder', 'Pseudo Perplexity'])
|
|
|
|
| 302 |
file_count="single",
|
| 303 |
file_types=[".tsv", ".csv"],
|
| 304 |
type="filepath",
|
| 305 |
+
height=10,
|
| 306 |
)
|
| 307 |
gr.Markdown(
|
| 308 |
"## Predict peptide sequence:"
|
|
|
|
| 327 |
)
|
| 328 |
with gr.Column(variant="compact", scale = 2):
|
| 329 |
predict_btn = gr.Button(
|
| 330 |
+
value="Predict peptide sequence from a protein sequence",
|
| 331 |
+
interactive=True,
|
| 332 |
+
variant="primary",
|
| 333 |
+
)
|
| 334 |
+
plot_struc_btn = gr.Button(value = "Plot ESMFold predicted structure ", variant="primary")
|
| 335 |
+
|
| 336 |
+
predict_file_btn = gr.Button(
|
| 337 |
+
value="Predict peptide from a local file",
|
| 338 |
interactive=True,
|
| 339 |
variant="primary",
|
| 340 |
)
|
|
|
|
| 341 |
with gr.Row():
|
| 342 |
with gr.Column(variant="compact", scale = 5):
|
| 343 |
output_text = gr.Textbox(
|
|
|
|
| 352 |
interactive=True,
|
| 353 |
variant="primary",
|
| 354 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
with gr.Row():
|
| 356 |
output_viewer = gr.HTML()
|
| 357 |
output_file = gr.File(
|