Spaces:
Paused
Paused
File size: 21,785 Bytes
dd4330e 2334c10 cbdb785 0c3de7a 9c1eb20 cdb0d81 9c1eb20 cdb0d81 ccb91c3 cdb0d81 8e6f926 9c1eb20 5620fc5 ce47dc1 f1bba3b 19d3e69 e381e89 19d3e69 519edc2 ac925ab 9c2e8dd 00f906e 4d87697 19d3e69 db7f9c6 19d3e69 ff7b8e3 e413339 a9f4b7f 19d3e69 752caab 19d3e69 a9f4b7f d825dd9 19d3e69 cd0b38f a9f4b7f 19d3e69 2d9177f 19d3e69 db7f9c6 5926c76 19d3e69 f4a25aa 19d3e69 752caab 19d3e69 ff7b8e3 2a6013d ff7b8e3 19d3e69 2aa5124 19d3e69 2aa5124 19d3e69 c7e53bf 19d3e69 7f2ddaa 19d3e69 2a6013d 19d3e69 519edc2 5620fc5 d63575c ae5cd29 519edc2 cdb0d81 519edc2 cdb0d81 519edc2 fa9d40f 519edc2 d82df2e 519edc2 f18f4a7 5620fc5 f18f4a7 7362684 0c3de7a 519edc2 4f846da 0c1f112 5620fc5 0c3de7a f4a25aa 0c3de7a ae5cd29 f4a25aa ae5cd29 f4a25aa 303be44 f4a25aa 303be44 0c3de7a f4a25aa 0c3de7a 96c2d1e 6948058 efd4eb0 9526b25 b730525 d89da81 0c3de7a 5620fc5 f18f4a7 7362684 f18f4a7 7362684 4f846da 6ead239 4f846da 334cdb5 4f846da 6ead239 4f846da d55ab1f 6ead239 4f846da 0c3de7a c7d6740 5620fc5 f18f4a7 519edc2 9526b25 aa0a8ad 519edc2 aa0a8ad f0e02fb 519edc2 c7d6740 ae5cd29 519edc2 fa9d40f 519edc2 aa0a8ad 519edc2 6ffebf7 06f1a01 519edc2 0c3de7a 519edc2 9161c6a 0c3de7a 4f846da b596af3 4f846da c7d6740 aa0a8ad 519edc2 3da800a 519edc2 5620fc5 f18f4a7 5539efa f18f4a7 9526b25 519edc2 cdb0d81 287064c cdb0d81 c7d6740 cdb0d81 287064c cdb0d81 c7d6740 efd4eb0 cdb0d81 c7d6740 07768ea 31b25c9 97580aa 6ffebf7 97580aa b730525 06ca46a 3730722 53edd52 ddfefab 871c9c4 53edd52 c7d6740 53edd52 7753a80 5539efa ff7b8e3 63dab7e 4a62eed 4e5ecd0 19d3dce b594185 19d3dce c7d6740 4a62eed 033b440 4a62eed ddfefab 4a62eed c7d6740 6a8b48e 4a62eed 5d04074 f5eb425 4a62eed 19251cb fb8afd6 19251cb fb8afd6 53edd52 9769ada fb8afd6 871c9c4 a982de1 92cd7a7 2334c10 cdb0d81 2334c10 d1f6cab 2334c10 d1f6cab 2334c10 3039822 2334c10 3039822 2334c10 9f3ddc0 fa9d40f 6fad7e1 c7d6740 2334c10 cdb0d81 2334c10 f7a164f 92cd7a7 f7a164f 2334c10 8074cb1 2334c10 cdb0d81 2334c10 db7f9c6 2334c10 cdb0d81 2334c10 cdb0d81 c7d6740 8074cb1 86e36a2 8074cb1 c7d6740 86e36a2 
2334c10 9fa463b caecccc 2334c10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 |
#ref: https://huggingface.co/blog/AmelieSchreiber/esm-interact
import gradio as gr
import os
from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback, EsmForProteinFolding
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.optim import AdamW
from torch.distributions import Categorical
import pandas as pd
from pathlib import Path
#import wandb
import numpy as np
from datetime import datetime
import time
from plot_pdb import plot_struc
import requests
import Bio.PDB
# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
# SECURITY FIX: never print the raw token — it is a secret and Space logs are
# visible. Only report whether it was found in the environment.
print("HF_TOKEN set:", HF_TOKEN is not None)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Base ESM-2 checkpoints users can choose from (also the tokenizer source).
MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # models users can choose from
# Fine-tuned checkpoints; the last three entries are the original foundation
# models, kept so inference can be run without fine-tuning.
PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-09-19_06",
    "wangjin2000/esm2_t12_35M_UR50D_PPI_2024-09-19_09",
    "wangjin2000/esm2_t33_650M_UR50D_PPI_2024-09-19_10",
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-11-19_04",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # finetuned models, but the last three models are original foundation models
#build datasets
class ProteinDataset(Dataset):
    """(protein, peptide) pairs for masked-peptide language modeling.

    Each item is the protein followed by a fully masked peptide; labels are
    -100 everywhere except the masked peptide positions, so only the peptide
    contributes to the MLM loss.
    """

    def __init__(self, file, tokenizer, peptide_length):
        frame = pd.read_csv(file)
        self.tokenizer = tokenizer
        # Column headers defined by Lin Qiao: P_Sequence = protein,
        # p_Sequence = peptide binder.
        self.proteins = frame["P_Sequence"].tolist()
        self.peptides = frame["p_Sequence"].tolist()
        # Assume proteins are at most 1000 residues (+2 special tokens).
        self.max_length_pm = 1000 + 2 + peptide_length

    def __len__(self):
        return len(self.proteins)

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]
        complex_seq = protein_seq + '<mask>' * len(peptide_seq)
        # Tokenize and pad the protein + masked-peptide complex.
        encoded = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length_pm,
            truncation=True,
        ).to(device)
        input_ids = encoded["input_ids"].squeeze()
        attention_mask = encoded["attention_mask"].squeeze()
        # Ground-truth tokens for the full (unmasked) complex.
        labels = self.tokenizer(
            protein_seq + peptide_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length_pm,
            truncation=True,
        )["input_ids"].to(device).squeeze()
        # Ignore every non-masked position in the loss.
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# fine-tuning function
# Fine-tune a base ESM-2 masked-LM on the peptide-protein dataset.
def finetune(base_model_path, peptide_length): #, train_dataset, test_dataset):
    """Fine-tune the ESM-2 model at `base_model_path` and push it to the HF Hub.

    Args:
        base_model_path: HF model id of the base ESM-2 checkpoint.
        peptide_length: peptide length used to size the masked region in the dataset.

    Returns:
        The output directory name (also used as the Hub repo name) of the run.
    """
    # Load the base masked-LM onto the module-level device.
    base_model = EsmForMaskedLM.from_pretrained(base_model_path).to(device)
    # Tokenizer comes from the same base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(base_model_path) #("facebook/esm2_t12_35M_UR50D")
    # Train/eval CSVs with P_Sequence / p_Sequence columns (see ProteinDataset).
    train_dataset = ProteinDataset("./datasets/peptide-protein-train.csv", tokenizer, peptide_length)
    eval_dataset = ProteinDataset("./datasets/peptide-protein-eval.csv", tokenizer, peptide_length)
    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    # NOTE(review): learning rate looks like the result of an earlier
    # hyper-parameter search — TODO confirm its provenance.
    lr = 0.0007984276816171436
    save_path = f"{model_name_base}_PPI_{timestamp}"
    training_args = TrainingArguments(
        output_dir=save_path, #f"{model_name_base}_PPI_{timestamp}",
        num_train_epochs = 5,
        per_device_train_batch_size = 2,
        per_device_eval_batch_size = 16,
        warmup_steps = 501,
        logging_dir=None,
        logging_steps=10,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        save_strategy='epoch',
        metric_for_best_model='eval_loss',
        save_total_limit = 5,
        gradient_accumulation_steps=2,
        push_to_hub=True, #jw 20240918
        hub_token = HF_TOKEN, #jw 20240918
        dataloader_pin_memory=False, #jw 20241119 true for CPU
    )
    # Initialize Trainer with an explicit AdamW optimizer (default LR scheduler).
    trainer = Trainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizers=(AdamW(base_model.parameters(), lr=lr), None),
    )
    # Train and save/push the model; best checkpoint is restored at the end.
    trainer.train()
    return save_path
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Pseudo-perplexity of `binder_seq` given `protein_seq` under a masked LM.

    Builds one row per binder position with that single token masked, scores
    all rows in one forward pass, and returns exp(mean masked-token loss).

    Args:
        model: masked-LM with a `.device` attribute, returning `.loss`.
        tokenizer: matching tokenizer (provides `encode` and `mask_token_id`).
        protein_seq: target protein sequence.
        binder_seq: peptide binder to score.

    Returns:
        Pseudo-perplexity as a Python float (lower is better).
    """
    start = time.time()
    sequence = protein_seq + binder_seq
    original_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
    length_of_binder = len(binder_seq)
    # One batch row per binder position, each with exactly one masked token.
    masked_inputs = original_input.repeat(length_of_binder, 1)
    # Negative indices select the binder tokens (skipping the final special token).
    positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)
    masked_inputs[torch.arange(length_of_binder), positions_to_mask] = tokenizer.mask_token_id
    # Labels: ignore everything except the single masked token in each row.
    labels = torch.full_like(masked_inputs, -100)
    labels[torch.arange(length_of_binder), positions_to_mask] = original_input[0, positions_to_mask]
    # Single forward pass; the model averages the loss over masked tokens.
    with torch.no_grad():
        outputs = model(masked_inputs, labels=labels)
        loss = outputs.loss
    avg_loss = loss.item()
    pseudo_perplexity = np.exp(avg_loss)
    end = time.time()
    elapsed = end - start
    #print(f'compute_pseudo_perplexity time: {elapsed:.4f} seconds')
    return pseudo_perplexity
# compute pLLDT and iPMT from ESMFOLD model directly, very slow
def compute_plddt_iptm(protein_seq, binder_seq):
start = time.time()
# always the ESMFold model
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model.eval()
# based on https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb
#model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
#model = model.cuda()
#model.esm = model.esm.half() # Uncomment to switch the stem to float16
#torch.backends.cuda.matmul.allow_tf32 = True
print("168:model.device in compute_plddt_iptm",model.device)
sequence = protein_seq + binder_seq
inputs = tokenizer(sequence, return_tensors='pt', add_special_tokens=False).to(device)
# Get model predictions
with torch.no_grad():
outputs = model(**inputs)
plddt = outputs.plddt
ptm = outputs.ptm.item()
avg_plddt = plddt[0,:,1].mean().item()
#iPTM = ptm
#print("170: iPTM:",iPTM)
end = time.time()
elapsed = end - start
#print(f'compute_plddt_iptm time: {elapsed:.4f} seconds')
return avg_plddt, ptm
# call web API of ESMFOLD to get pLLDT
def get_plddt(protein_seq, binder_seq):
start = time.time()
sequence = protein_seq + binder_seq
retries = 0
pdb_str = None
url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
while retries < 3 and pdb_str is None:
response = requests.post(url, data=sequence, verify=False)
pdb_str = response.text
if pdb_str == "INTERNAL SERVER ERROR":
retries += 1
time.sleep(0.1)
pdb = None #pdb = str = None
#save a pdb format file
name = sequence[:3] + sequence[-3:] #combine the firt and last 3 AAs of sequence as a filename.
outpath = (
Path.cwd() / f"PDB-{name}.pdb")
with open(outpath.name, "w") as f:
f.write(pdb_str)
outpath_str = str(outpath)
#get pdb column values
p = Bio.PDB.PDBParser()
structure = p.get_structure('myStructureName', outpath_str)
ids = [a.get_id() for a in structure.get_atoms()]
pLDDTs = [a.get_bfactor() for a in structure.get_atoms()]
avg_plddt = np.mean(pLDDTs)
ptm = 0 #place holder for iPTM
end = time.time()
elapsed = end - start
print(f'get_plddt time: {elapsed:.4f} seconds')
return avg_plddt, ptm
def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 5, plddt_iptm_yes="no"):
    """Sample `num_binders` peptide binders for one protein from a masked LM.

    Appends `peptide_length` mask tokens to the protein, fills them with
    top-k sampling from the model's logits, and scores each candidate.

    Args:
        model: masked-LM used for generation and pseudo-perplexity scoring.
        tokenizer: matching tokenizer.
        protein_seq: target protein sequence.
        peptide_length: number of residues to generate per binder.
        top_k: sampling width per masked position.
        num_binders: number of candidate binders to generate.
        plddt_iptm_yes: "yes" to also compute pLDDT via the ESMAtlas API (slow).

    Returns:
        List of [binder, pseudo_perplexity, plddt, iptm] entries.
    """
    start = time.time()
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)
    binders_with_ppl_plddt = []
    for _ in range(num_binders):
        # Generate a binder: protein followed by a fully masked peptide.
        masked_peptide = '<mask>' * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]
        # Top-k sampling: keep the k best logits per position, sample from
        # their softmax, then map back to vocabulary ids.
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')
        # Pseudo-perplexity of the generated binder given the protein.
        ppl = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        # Optionally get pLDDT from the ESMAtlas web API (slow).
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, generated_binder)
        else:
            plddt, iptm = 0, 0
        binders_with_ppl_plddt.append([generated_binder, ppl, plddt, iptm])
    end = time.time()
    elapsed = end - start
    print(f'generate_peptide_for_single_sequence: {elapsed:.4f} seconds')
    return binders_with_ppl_plddt
# Predict peptide binder with finetuned model
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, num_binders=4, plddt_iptm_yes="no"):
# Load the model
loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path) #.to(device) inference use cpu
# Ensure the model is in evaluation mode
loaded_model.eval()
# Tokenization
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# set top_k mutations for each AA position
top_k=3
if isinstance(input_seqs, str): # Single sequence
binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders, plddt_iptm_yes)
results_df = pd.DataFrame(binders, columns=['Binder', 'PPL', 'pLDDT', 'iPTM'])
elif isinstance(input_seqs, list): # List of sequences
results = []
for seq in input_seqs:
binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
for binder, ppl, plddt, iptm in binders:
results.append([seq, binder, ppl, plddt, iptm])
results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM'])
print(results_df)
#combine target protein and predicted peptide with 20 G amino acids.
separator = 'G' * 20
peptide_lp = results_df['Binder'][results_df['PPL'].idxmin()] #Choosing the one with the lowest perplexity
PPC = input_seqs + separator + peptide_lp
return results_df, PPC
def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, max_peptide_length=15, num_binders=5, plddt_iptm_yes="no"):
    """Score ground-truth peptides and generate binders for every row of a CSV.

    The CSV must contain 'P_Sequence' (protein) and 'p_Sequence' (peptide)
    columns. For each row the ground-truth peptide is scored (GT_Flag=1) and
    `num_binders` candidates of matching length are generated (GT_Flag=0).

    Returns:
        Path (str) of the timestamped results CSV, for use with gr.File.
    """
    start = time.time()
    # Load the fine-tuned model for inference and the base tokenizer.
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
    loaded_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    # Renamed from `input`, which shadowed the builtin.
    data = pd.read_csv(file_obj, header=0)
    results = []
    top_k = 3  # top-k mutations for each AA position (loop-invariant, hoisted)
    for i, row in data.iterrows():
        print("sequence:", i)
        protein_seq = row['P_Sequence']
        peptide_seq = row['p_Sequence']
        # Match the ground-truth peptide length, capped at max_peptide_length.
        peptide_length = min(len(peptide_seq), max_peptide_length)
        # Metrics for the ground-truth peptide.
        ppl = compute_pseudo_perplexity(loaded_model, tokenizer, protein_seq, peptide_seq)
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, peptide_seq)
        else:
            plddt, iptm = 0, 0
        results.append([protein_seq, peptide_seq, ppl, plddt, iptm, 1])  # flag 1 = ground truth
        # Generated candidates.
        binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, protein_seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
        for binder, ppl, plddt, iptm in binders:
            results.append([protein_seq, binder, ppl, plddt, iptm, 0])  # flag 0 = generated
    results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM', 'GT_Flag'])
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    outpath = Path.cwd() / f"predicted_peptides_{timestamp}.csv"
    str_outpath = str(outpath)  # the latest gr.File needs a string-type input
    results_df.to_csv(str_outpath, header=True, index=False)
    end = time.time()
    elapsed = end - start
    print(f'predict_peptide_from_file: {elapsed:.4f} seconds')
    return str_outpath
def suggest(option):
    """Return the sample protein sequence for a dropdown choice ("" if unknown)."""
    samples = {
        "Protein:P63279": "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
        "Default protein": "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT",
        "Antifreeze protein": "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH",
        "AI Generated protein": "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS",
        "7-bladed propeller fold": "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK",
    }
    return samples.get(option, "")
# ---------------------------------------------------------------------------
# Gradio UI: layout and event wiring. Runs at import time; indentation below
# is reconstructed since the original layout was lost in extraction.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="ESM2 for Micropertide Protein Interaction (ESM2MPI)")
with demo:
    gr.Markdown("# ESM2 for Micropertide Protein Interaction (ESM2MPI)")
    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding finetuned model")
        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                # Base checkpoint (tokenizer source) and fine-tuned checkpoint.
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive = True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="Finetuned Model Name",
                    interactive = True,
                )
                with gr.Row():
                    # Generation controls.
                    peptide_length=gr.Slider(minimum=10, maximum=100, step=1, label="Peptide Maximum Length", value=15)
                    num_pred_peptides=gr.Slider(minimum=1, maximum=10, step=1, label="Number of Predicted Peptides", value=5)
                    plddt_iptm_yes=gr.Radio(["yes", "no"],label="Compute pLDDT and iPTM (slow!)", value="no")
            with gr.Column(scale=5, variant="compact"):
                # Sample protein selector and optional CSV/TSV upload.
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Protein:P63279",
                    choices=["Default protein", "Antifreeze protein", "Protein:P63279", "AI Generated protein", "7-bladed propeller fold", "custom"]
                )
                uploaded_file = gr.File(
                    label="Local File Upload",
                    file_count="single",
                    file_types=[".tsv", ".csv"],
                    type="filepath",
                    height=40,
                )
        gr.Markdown(
            "## Predict peptide sequence:"
        )
        with gr.Row():
            with gr.Column(variant="compact", scale = 8):
                # Target protein sequence (editable) and a fixed position ruler.
                input_seq = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Protein:P63279 to be predicted:",
                    value="MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
                    placeholder="Paste your protein sequence here...",
                    interactive = True,
                )
                text_pos = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Sequency Position:",
                    placeholder=
                    "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                    interactive=False,
                )
            with gr.Column(variant="compact", scale = 2):
                # Action buttons.
                predict_btn = gr.Button(
                    value="Predict peptide sequence from a protein sequence",
                    interactive=True,
                    variant="primary",
                )
                plot_struc_btn = gr.Button(value = "Plot ESMFold predicted structure ", variant="primary")
                predict_file_btn = gr.Button(
                    value="Predict peptide from a local file",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            with gr.Column(variant="compact", scale = 5):
                # Text output for prediction / fine-tuning results.
                output_text = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Output",
                    placeholder="Output",
                )
            with gr.Column(variant="compact", scale = 5):
                finetune_button = gr.Button(
                    value="Finetune Pre-trained Model",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            # Structure viewer and downloadable results file.
            output_viewer = gr.HTML()
            output_file = gr.File(
                label="Download as Text File",
                file_count="single",
                type="filepath",
                interactive=False,
            )
    # Select a protein sample -> fill the input textbox.
    name.change(fn=suggest, inputs=name, outputs=input_seq)
    # "Predict peptide sequence" actions.
    predict_btn.click(
        fn = predict_peptide,
        inputs=[base_model_name,PEFT_model_name,input_seq,peptide_length,num_pred_peptides,plddt_iptm_yes],
        outputs = [output_text, input_seq],
    )
    # "Predict peptide from a local file" actions.
    predict_file_btn.click(
        fn = predict_peptide_from_file,
        inputs=[base_model_name,PEFT_model_name,uploaded_file,peptide_length,num_pred_peptides,plddt_iptm_yes],
        outputs = [output_file],
    )
    # "Finetune Pre-trained Model" actions.
    finetune_button.click(
        fn = finetune,
        inputs=[base_model_name,peptide_length],
        outputs = [output_text],
    )
    # Plot the predicted protein structure.
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])
demo.launch()