# ESM2PPI / app.py
# Author: wangjin2000 (last commit b594185, "Update app.py")
#ref: https://huggingface.co/blog/AmelieSchreiber/esm-interact
import gradio as gr
import os
from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback, EsmForProteinFolding
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.optim import AdamW
from torch.distributions import Categorical
import pandas as pd
from pathlib import Path
#import wandb
import numpy as np
from datetime import datetime
import time
from plot_pdb import plot_struc
import requests
import Bio.PDB
# Constants & Globals
# Hugging Face access token, read from the "HF_token" environment variable; used by
# `finetune` to push checkpoints to the Hub. Never print the token value itself --
# Space logs can be visible to others -- only whether it is configured.
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN set:", HF_TOKEN is not None)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # base models users can choose from
PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-09-19_06",
    "wangjin2000/esm2_t12_35M_UR50D_PPI_2024-09-19_09",
    "wangjin2000/esm2_t33_650M_UR50D_PPI_2024-09-19_10",
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-11-19_04",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]  # fine-tuned models; the last three entries are the original foundation models
#build datasets
class ProteinDataset(Dataset):
    """Protein/peptide pairs for masked-LM fine-tuning of an ESM-2 model.

    Each item is the protein sequence followed by one ``<mask>`` per peptide
    residue. The labels hold the ground-truth peptide tokens at the masked
    positions and -100 (ignored by the loss) everywhere else.

    Args:
        file: CSV path or buffer with columns ``P_Sequence`` (protein) and
            ``p_Sequence`` (peptide).
        tokenizer: Hugging Face tokenizer exposing ``mask_token_id``.
        peptide_length: maximum peptide length. Together with an assumed
            maximum protein length of 1000 (+2 special tokens), it fixes the
            padded length of every item.
    """

    def __init__(self, file, tokenizer, peptide_length):
        data = pd.read_csv(file)
        self.tokenizer = tokenizer
        self.proteins = data["P_Sequence"].tolist()  # header defined by Lin Qiao
        self.peptides = data["p_Sequence"].tolist()
        # Gradio sliders can deliver floats; the tokenizer's max_length must be an int.
        self.max_length_pm = 1000 + 2 + int(peptide_length)  # assume max protein length is 1000

    def __len__(self):
        return len(self.proteins)

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]
        encode_kwargs = dict(
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length_pm,
            truncation=True,
        )
        # Input: protein followed by a fully masked peptide.
        complex_input = self.tokenizer(protein_seq + "<mask>" * len(peptide_seq), **encode_kwargs)
        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()
        # Labels: ground-truth tokens, kept only where the input is masked.
        labels = self.tokenizer(protein_seq + peptide_seq, **encode_kwargs)["input_ids"].squeeze()
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)
        # Tensors stay on the CPU: the Trainer moves each batch to the model's device.
        # Creating CUDA tensors here breaks pinned memory and multi-worker data loading.
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
# fine-tuning function
def finetune(base_model_path, peptide_length):
    """Fine-tune an ESM-2 masked-LM on protein/peptide pairs and push it to the Hub.

    Trains on ``./datasets/peptide-protein-train.csv`` and evaluates on
    ``./datasets/peptide-protein-eval.csv``. Checkpoints are saved once per
    epoch, pushed to the Hub with ``HF_TOKEN``, and the checkpoint with the
    lowest eval loss is reloaded at the end.

    Args:
        base_model_path: Hub id or local path of the base ESM-2 model.
        peptide_length: maximum peptide length used to size the padded inputs.

    Returns:
        The output directory / Hub repo name, ``<model>_PPI_<YYYY-mm-dd_HH>``.
    """
    peptide_length = int(peptide_length)  # Gradio sliders can deliver floats
    base_model = EsmForMaskedLM.from_pretrained(base_model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    train_dataset = ProteinDataset("./datasets/peptide-protein-train.csv", tokenizer, peptide_length)
    eval_dataset = ProteinDataset("./datasets/peptide-protein-eval.csv", tokenizer, peptide_length)

    # Last path component, so local paths ("./models/foo") work as well as Hub ids ("org/name").
    model_name_base = base_model_path.rstrip("/").split("/")[-1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    lr = 0.0007984276816171436  # passed to the AdamW optimizer below
    save_path = f"{model_name_base}_PPI_{timestamp}"

    training_args = TrainingArguments(
        output_dir=save_path,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=16,
        warmup_steps=501,
        logging_dir=None,
        logging_steps=10,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        save_strategy='epoch',
        metric_for_best_model='eval_loss',
        save_total_limit=5,
        gradient_accumulation_steps=2,
        push_to_hub=True,  # jw 20240918
        hub_token=HF_TOKEN,  # jw 20240918
        dataloader_pin_memory=False,  # jw 20241119 true for CPU
    )
    trainer = Trainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # Custom optimizer; the Trainer still builds its LR scheduler around it.
        optimizers=(AdamW(base_model.parameters(), lr=lr), None),
    )
    # Checkpoints are saved (and pushed) per epoch by the Trainer itself.
    trainer.train()
    return save_path
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Return the pseudo-perplexity of ``binder_seq`` given ``protein_seq`` (lower is better).

    Each binder residue is masked in turn: one masked copy of the encoded
    complex per residue, scored in a single batch. The result is the
    exponential of the model's mean masked-token cross-entropy.

    Assumes the tokenizer adds exactly one special token at each end and maps
    one residue to one token. Under that assumption the binder occupies
    positions ``[-len(binder_seq) - 1, -1)`` of the encoded sequence.

    Returns ``nan`` for an empty binder, since there is nothing to score.
    """
    length_of_binder = len(binder_seq)
    if length_of_binder == 0:
        return float("nan")

    original_input = tokenizer.encode(protein_seq + binder_seq, return_tensors='pt').to(model.device)
    rows = torch.arange(length_of_binder, device=model.device)
    # Binder tokens sit just before the trailing special token.
    positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)

    # Row i masks binder residue i only.
    masked_inputs = original_input.repeat(length_of_binder, 1)
    masked_inputs[rows, positions_to_mask] = tokenizer.mask_token_id

    # Score only the masked positions; -100 is ignored by the loss.
    labels = torch.full_like(masked_inputs, -100)
    labels[rows, positions_to_mask] = original_input[0, positions_to_mask]

    with torch.no_grad():
        loss = model(masked_inputs, labels=labels).loss  # already averaged by the model
    return np.exp(loss.item())
# Compute pLDDT and pTM with a locally loaded ESMFold model (very slow).
def compute_plddt_iptm(protein_seq, binder_seq):
    """Fold protein+binder with ``facebook/esmfold_v1`` and return ``(avg_plddt, ptm)``.

    The ESMFold model and tokenizer are loaded from scratch on every call, which
    is why this path is slow. ``avg_plddt`` is the mean of
    ``outputs.plddt[0, :, 1]``. ``ptm`` is the model's pTM score, which callers
    use as a stand-in for iPTM.
    Based on https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb
    """
    esmfold = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
    fold_tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    esmfold.eval()
    print("168:model.device in compute_plddt_iptm", esmfold.device)

    # ESMFold takes the raw residue string, without special tokens.
    encoded = fold_tokenizer(protein_seq + binder_seq, return_tensors='pt', add_special_tokens=False).to(device)
    with torch.no_grad():
        folded = esmfold(**encoded)

    ptm_score = folded.ptm.item()
    # Index 1 of the last axis -- presumably the CA atom in ESMFold's atom37 layout; TODO confirm.
    mean_plddt = folded.plddt[0, :, 1].mean().item()
    return mean_plddt, ptm_score
# call web API of ESMFOLD to get pLDDT
def get_plddt(protein_seq, binder_seq, max_retries=3, timeout=120):
    """Fold protein+binder with the public ESM Atlas API and return its mean pLDDT.

    The API responds with a PDB file. That file is saved as
    ``PDB-<first 3 aa><last 3 aa>.pdb`` in the current directory, and the
    B-factor column (where the API stores pLDDT) is averaged over all atoms.

    Args:
        protein_seq: target protein sequence.
        binder_seq: peptide sequence appended directly to the protein.
        max_retries: number of POST attempts before giving up.
        timeout: per-request timeout in seconds, so a stalled API cannot hang the app.

    Returns:
        ``(avg_plddt, 0)``. The second value is a placeholder for iPTM, which this
        API does not provide. ``avg_plddt`` is ``nan`` if every attempt fails or
        the response contains no atoms.
    """
    start = time.time()
    sequence = protein_seq + binder_seq
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"

    pdb_str = None
    for attempt in range(max_retries):
        try:
            # NOTE(review): verify=False (kept from the original) disables TLS certificate
            # checks; re-enable it once the endpoint's certificate validates.
            response = requests.post(url, data=sequence, verify=False, timeout=timeout)
        except requests.RequestException as err:
            print(f"get_plddt attempt {attempt + 1} failed: {err}")
        else:
            if response.ok and response.text != "INTERNAL SERVER ERROR":
                pdb_str = response.text
                break
            print(f"get_plddt attempt {attempt + 1} failed: HTTP {response.status_code}")
        if attempt + 1 < max_retries:
            time.sleep(0.1)

    if pdb_str is None:
        # Do not write the error text to a .pdb file and parse it as a structure.
        print(f'get_plddt time: {time.time() - start:.4f} seconds')
        return float("nan"), 0

    # Save the structure; the name combines the first and last 3 residues of the sequence.
    name = sequence[:3] + sequence[-3:]
    outpath = Path.cwd() / f"PDB-{name}.pdb"
    outpath.write_text(pdb_str)  # same path is written and then parsed

    structure = Bio.PDB.PDBParser().get_structure('myStructureName', str(outpath))
    bfactors = [atom.get_bfactor() for atom in structure.get_atoms()]
    avg_plddt = np.mean(bfactors) if bfactors else float("nan")
    ptm = 0  # placeholder for iPTM

    elapsed = time.time() - start
    print(f'get_plddt time: {elapsed:.4f} seconds')
    return avg_plddt, ptm
def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length = 15, top_k = 3, num_binders = 5, plddt_iptm_yes="no"):
    """Sample ``num_binders`` peptide binders for ``protein_seq`` by top-k sampling.

    The protein is followed by ``peptide_length`` mask tokens. At every masked
    position, one token is drawn from the softmax over the ``top_k`` highest
    logits. Special tokens in a sample are dropped when decoding, so a binder
    can be shorter than ``peptide_length``.

    Returns:
        A list of ``[binder, pseudo_perplexity, plddt, iptm]`` rows. ``plddt`` and
        ``iptm`` are 0 unless ``plddt_iptm_yes == "yes"``.
    """
    start = time.time()
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The masked input and the model's logits are the same for every sample, so run
    # the forward pass once and repeat only the cheap sampling step per binder.
    input_sequence = protein_seq + '<mask>' * peptide_length
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    top_k_logits, top_k_indices = logits[0, mask_token_indices].topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)

    binders_with_ppl_plddt = []
    for _ in range(num_binders):
        # Draw one of the top-k tokens independently at every masked position.
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        ppl = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        if plddt_iptm_yes == "yes":
            # The local ESMFold path (compute_plddt_iptm) is too slow; use the web API.
            plddt, iptm = get_plddt(protein_seq, generated_binder)
        else:
            plddt, iptm = 0, 0
        binders_with_ppl_plddt.append([generated_binder, ppl, plddt, iptm])

    elapsed = time.time() - start
    print(f'generate_peptide_for_single_sequence: {elapsed:.4f} seconds')
    return binders_with_ppl_plddt
# Predict peptide binder with finetuned model
def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, num_binders=4, plddt_iptm_yes="no"):
    """Generate peptide binders for one protein (str) or several (list of str).

    Returns:
        ``(results_df, PPC)``. ``results_df`` has one row per generated binder;
        for list input it gets an extra leading 'Input Sequence' column. ``PPC``
        is the protein of the lowest-perplexity row, then a 20×'G' linker, then
        that row's binder.

    Raises:
        TypeError: if ``input_seqs`` is neither a str nor a list.
    """
    # Inference runs on whatever device from_pretrained loads to (CPU here).
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
    loaded_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    top_k = 3  # candidate tokens sampled from at each peptide position

    if isinstance(input_seqs, str):  # single sequence
        binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders, plddt_iptm_yes)
        results_df = pd.DataFrame(binders, columns=['Binder', 'PPL', 'pLDDT', 'iPTM'])
    elif isinstance(input_seqs, list):  # list of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
            results.extend([seq, *binder_row] for binder_row in binders)
        results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM'])
    else:
        raise TypeError(f"input_seqs must be a str or a list of str, got {type(input_seqs).__name__}")
    print(results_df)

    # Combine the target protein and its lowest-perplexity peptide with a 20-residue G linker.
    separator = 'G' * 20
    best = results_df['PPL'].idxmin()
    # For list input, use the protein that the best binder was generated for.
    target = input_seqs if isinstance(input_seqs, str) else results_df.at[best, 'Input Sequence']
    PPC = target + separator + results_df.at[best, 'Binder']
    return results_df, PPC
def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, max_peptide_length=15, num_binders=5, plddt_iptm_yes="no"):
    """Score ground-truth peptides and generate new ones for every row of an uploaded table.

    The file (CSV, or TSV when the name ends in ``.tsv``) needs columns
    ``P_Sequence`` (protein) and ``p_Sequence`` (ground-truth peptide). For each
    row, the ground-truth peptide is scored (GT_Flag 1). Then ``num_binders``
    peptides are generated (GT_Flag 0); each is as long as the ground truth,
    capped at ``max_peptide_length``.

    Returns:
        Path (as str, which ``gr.File`` requires) of the written
        ``predicted_peptides_<YYYY-mm-dd_HH>.csv``.

    Raises:
        gr.Error: if no file was uploaded.
    """
    if file_obj is None:
        raise gr.Error("Please upload a CSV/TSV file first.")
    start = time.time()
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
    loaded_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # The uploader accepts .tsv as well as .csv, so choose the delimiter accordingly.
    sep = "\t" if str(file_obj).lower().endswith(".tsv") else ","
    table = pd.read_csv(file_obj, header=0, sep=sep)
    top_k = 3  # candidate tokens sampled from at each peptide position

    results = []
    for i, row in table.iterrows():
        print("sequence:", i)
        protein_seq = row['P_Sequence']
        peptide_seq = row['p_Sequence']
        # Predict peptides as long as the ground truth, capped at max_peptide_length.
        peptide_length = min(len(peptide_seq), max_peptide_length)

        # Metrics for the ground-truth peptide.
        ppl = compute_pseudo_perplexity(loaded_model, tokenizer, protein_seq, peptide_seq)
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, peptide_seq)
        else:
            plddt, iptm = 0, 0
        results.append([protein_seq, peptide_seq, ppl, plddt, iptm, 1])  # flag 1: ground truth

        binders = generate_peptide_for_single_sequence(loaded_model, tokenizer, protein_seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
        results.extend([protein_seq, *binder_row, 0] for binder_row in binders)  # flag 0: generated

    results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM', 'GT_Flag'])
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    str_outpath = str(Path.cwd() / f"predicted_peptides_{timestamp}.csv")  # gr.File needs a str path
    results_df.to_csv(str_outpath, header=True, index=False)

    elapsed = time.time() - start
    print(f'predict_peptide_from_file: {elapsed:.4f} seconds')
    return str_outpath
def suggest(option):
    """Return the sample sequence for a dropdown label, or "" for "custom"/unknown labels."""
    sample_table = (
        ("Protein:P63279", "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS"),
        # Previous default: "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
        ("Default protein", "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"),
        ("Antifreeze protein", "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"),
        ("AI Generated protein", "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"),
        ("7-bladed propeller fold", "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"),
    )
    # Compare with == in the original order, so any option value behaves as before.
    for label, sequence in sample_table:
        if option == label:
            return sequence
    return ""
# Gradio UI: model selection, sample/uploaded inputs, prediction, fine-tuning and structure plotting.
demo = gr.Blocks(title="ESM2 for Micropeptide Protein Interaction (ESM2MPI)")
with demo:
    gr.Markdown("# ESM2 for Micropeptide Protein Interaction (ESM2MPI)")
    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding finetuned model")
        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive=True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="Finetuned Model Name",
                    interactive=True,
                )
                with gr.Row():
                    peptide_length = gr.Slider(minimum=10, maximum=100, step=1, label="Peptide Maximum Length", value=15)
                    num_pred_peptides = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Predicted Peptides", value=5)
                    plddt_iptm_yes = gr.Radio(["yes", "no"], label="Compute pLDDT and iPTM (slow!)", value="no")
            with gr.Column(scale=5, variant="compact"):
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Protein:P63279",
                    choices=["Default protein", "Antifreeze protein", "Protein:P63279", "AI Generated protein", "7-bladed propeller fold", "custom"],
                )
                uploaded_file = gr.File(
                    label="Local File Upload",
                    file_count="single",
                    file_types=[".tsv", ".csv"],
                    type="filepath",
                    height=40,
                )
        gr.Markdown("## Predict peptide sequence:")
        with gr.Row():
            with gr.Column(variant="compact", scale=8):
                input_seq = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    # Generic label: the textbox content changes with the sample dropdown.
                    label="Protein sequence to be predicted:",
                    value="MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
                    placeholder="Paste your protein sequence here...",
                    interactive=True,
                )
                # Read-only position ruler shown under the sequence.
                text_pos = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Sequence Position:",
                    placeholder=
                    "012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                    interactive=False,
                )
            with gr.Column(variant="compact", scale=2):
                predict_btn = gr.Button(
                    value="Predict peptide sequence from a protein sequence",
                    interactive=True,
                    variant="primary",
                )
                plot_struc_btn = gr.Button(value="Plot ESMFold predicted structure ", variant="primary")
                predict_file_btn = gr.Button(
                    value="Predict peptide from a local file",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            with gr.Column(variant="compact", scale=5):
                output_text = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Output",
                    placeholder="Output",
                )
            with gr.Column(variant="compact", scale=5):
                finetune_button = gr.Button(
                    value="Finetune Pre-trained Model",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            output_viewer = gr.HTML()
            output_file = gr.File(
                label="Download as Text File",
                file_count="single",
                type="filepath",
                interactive=False,
            )

    # Event wiring (must be registered inside the Blocks context).
    # Selecting a sample protein fills the input textbox.
    name.change(fn=suggest, inputs=name, outputs=input_seq)
    # predict_peptide returns (results table, protein + GGG... linker + best peptide).
    # The second value replaces the input textbox contents.
    predict_btn.click(
        fn=predict_peptide,
        inputs=[base_model_name, PEFT_model_name, input_seq, peptide_length, num_pred_peptides, plddt_iptm_yes],
        outputs=[output_text, input_seq],
    )
    predict_file_btn.click(
        fn=predict_peptide_from_file,
        inputs=[base_model_name, PEFT_model_name, uploaded_file, peptide_length, num_pred_peptides, plddt_iptm_yes],
        outputs=[output_file],
    )
    finetune_button.click(
        fn=finetune,
        inputs=[base_model_name, peptide_length],
        outputs=[output_text],
    )
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])
demo.launch()