#ref: https://huggingface.co/blog/AmelieSchreiber/esm-interact
"""ESM2 for Micropeptide Protein Interaction (ESM2MPI) Gradio app.

Fine-tunes ESM2 masked-LM models on peptide-protein pairs, generates
candidate peptide binders for a target protein via top-k sampling of
masked positions, and optionally scores candidates with ESMFold
pLDDT/pTM (local model or ESM Atlas web API).
"""
import gradio as gr
import os
from transformers import Trainer, TrainingArguments, AutoTokenizer, EsmForMaskedLM, AutoModelForMaskedLM, TrainerCallback, EsmForProteinFolding
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.optim import AdamW
from torch.distributions import Categorical
import pandas as pd
from pathlib import Path
#import wandb
import numpy as np
from datetime import datetime
import time
from plot_pdb import plot_struc
import requests
import Bio.PDB

# Constants & Globals
HF_TOKEN = os.environ.get("HF_token")
print("HF_TOKEN:", HF_TOKEN)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Models users can choose from as the base (pre-trained) model.
MODEL_OPTIONS = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]

# Finetuned models; the last three entries are the original foundation models.
PEFT_MODEL_OPTIONS = [
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-09-19_06",
    "wangjin2000/esm2_t12_35M_UR50D_PPI_2024-09-19_09",
    "wangjin2000/esm2_t33_650M_UR50D_PPI_2024-09-19_10",
    "wangjin2000/esm2_t6_8M_UR50D_PPI_2024-11-19_04",
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]


class ProteinDataset(Dataset):
    """(protein, peptide) sequence pairs for masked-LM fine-tuning.

    Each item is the protein concatenated with a fully masked peptide;
    ``labels`` holds the ground-truth peptide token ids at the masked
    positions and -100 (ignored by the loss) everywhere else.
    """

    def __init__(self, file, tokenizer, peptide_length):
        data = pd.read_csv(file)
        self.tokenizer = tokenizer
        self.proteins = data["P_Sequence"].tolist()  # header defined by Lin Qiao
        self.peptides = data["p_Sequence"].tolist()
        # Assume the max length of a protein is 1000 (+2 special tokens).
        self.max_length_pm = 1000 + 2 + peptide_length

    def __len__(self):
        return len(self.proteins)

    def __getitem__(self, idx):
        protein_seq = self.proteins[idx]
        peptide_seq = self.peptides[idx]
        # BUGFIX: was "'' * len(peptide_seq)" which is always '', leaving no
        # masked positions and therefore no training signal. One mask token
        # per peptide residue is required.
        masked_peptide = self.tokenizer.mask_token * len(peptide_seq)
        complex_seq = protein_seq + masked_peptide

        # Tokenize and pad the protein+masked-peptide complex.
        complex_input = self.tokenizer(
            complex_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length_pm,
            truncation=True,
        ).to(device)
        input_ids = complex_input["input_ids"].squeeze()
        attention_mask = complex_input["attention_mask"].squeeze()

        # Labels: tokens for the ground-truth amino acids of the full complex.
        label_seq = protein_seq + peptide_seq
        labels = self.tokenizer(
            label_seq,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length_pm,
            truncation=True,
        )["input_ids"].to(device).squeeze()

        # Ignore every non-masked position in the loss.
        labels = torch.where(input_ids == self.tokenizer.mask_token_id, labels, -100)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def finetune(base_model_path, peptide_length):
    """Fine-tune a base ESM2 masked-LM on peptide-protein training data.

    Args:
        base_model_path: HF hub id of the base ESM2 model.
        peptide_length: max peptide length used to size tokenized inputs.

    Returns:
        The local output directory (also the hub repo name) the model was
        saved to.
    """
    base_model = EsmForMaskedLM.from_pretrained(base_model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    train_dataset = ProteinDataset("./datasets/peptide-protein-train.csv", tokenizer, peptide_length)
    eval_dataset = ProteinDataset("./datasets/peptide-protein-eval.csv", tokenizer, peptide_length)

    model_name_base = base_model_path.split("/")[1]
    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    lr = 0.0007984276816171436  # presumably from a prior hyperparameter sweep — TODO confirm
    save_path = f"{model_name_base}_PPI_{timestamp}"

    training_args = TrainingArguments(
        output_dir=save_path,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=16,
        warmup_steps=501,
        logging_dir=None,
        logging_steps=10,
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
        save_strategy='epoch',
        metric_for_best_model='eval_loss',
        save_total_limit=5,
        gradient_accumulation_steps=2,
        push_to_hub=True,  # jw 20240918
        hub_token=HF_TOKEN,  # jw 20240918
        dataloader_pin_memory=False,  # jw 20241119 true for CPU
    )

    trainer = Trainer(
        model=base_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        optimizers=(AdamW(base_model.parameters(), lr=lr), None),
    )

    # Train and save the model (push_to_hub handles upload).
    trainer.train()
    return save_path


def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Pseudo-perplexity of ``binder_seq`` given ``protein_seq``.

    Masks each binder position in turn (one row per position in a single
    batch), scores the masked tokens, and exponentiates the mean loss.
    Lower is better.
    """
    start = time.time()
    sequence = protein_seq + binder_seq
    print("129:model.device in compute_pseudo_perplexity", model.device)
    original_input = tokenizer.encode(sequence, return_tensors='pt').to(model.device)
    length_of_binder = len(binder_seq)

    # Batch with each row having exactly one masked token from the binder.
    masked_inputs = original_input.repeat(length_of_binder, 1)
    # Negative indices: binder tokens sit just before the final special token.
    positions_to_mask = torch.arange(-length_of_binder - 1, -1, device=model.device)
    masked_inputs[torch.arange(length_of_binder), positions_to_mask] = tokenizer.mask_token_id

    # Labels only at the masked positions; -100 elsewhere.
    labels = torch.full_like(masked_inputs, -100)
    labels[torch.arange(length_of_binder), positions_to_mask] = original_input[0, positions_to_mask]

    with torch.no_grad():
        outputs = model(masked_inputs, labels=labels)
        loss = outputs.loss  # already averaged over masked tokens by the model

    pseudo_perplexity = np.exp(loss.item())

    elapsed = time.time() - start
    #print(f'compute_pseudo_perplexity time: {elapsed:.4f} seconds')
    return pseudo_perplexity


def compute_plddt_iptm(protein_seq, binder_seq):
    """Compute pLDDT and pTM with the local ESMFold model (very slow).

    Based on:
    https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/protein_folding.ipynb
    """
    start = time.time()
    # Always the ESMFold model, regardless of the selected ESM2 variant.
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1").to(device)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model.eval()
    print("168:model.device in compute_plddt_iptm", model.device)

    sequence = protein_seq + binder_seq
    inputs = tokenizer(sequence, return_tensors='pt', add_special_tokens=False).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        plddt = outputs.plddt
        ptm = outputs.ptm.item()
        # Mean over residues of one pLDDT channel — assumes index 1 is the
        # CA-atom column; TODO confirm against ESMFold output layout.
        avg_plddt = plddt[0, :, 1].mean().item()

    elapsed = time.time() - start
    #print(f'compute_plddt_iptm time: {elapsed:.4f} seconds')
    return avg_plddt, ptm


def get_plddt(protein_seq, binder_seq):
    """Get average pLDDT by calling the ESM Atlas folding web API.

    Writes the returned PDB locally and averages the per-atom B-factors
    (which hold pLDDT). iPTM is not provided by the API; returns 0 as a
    placeholder.

    Raises:
        RuntimeError: if the API keeps failing after 3 retries.
    """
    start = time.time()
    sequence = protein_seq + binder_seq

    retries = 0
    pdb_str = None
    url = "https://api.esmatlas.com/foldSequence/v1/pdb/"
    while retries < 3 and pdb_str is None:
        # SECURITY NOTE: verify=False disables TLS certificate verification;
        # kept for parity with the original but should be reviewed.
        response = requests.post(url, data=sequence, verify=False)
        pdb_str = response.text
        if pdb_str == "INTERNAL SERVER ERROR":
            retries += 1
            pdb_str = None
            time.sleep(0.1)
    if pdb_str is None:
        # Fail loudly instead of writing an error page into a .pdb file.
        raise RuntimeError("ESM Atlas fold API failed after 3 retries")

    # Save a pdb file named from the first and last 3 AAs of the sequence.
    name = sequence[:3] + sequence[-3:]
    outpath = Path.cwd() / f"PDB-{name}.pdb"
    with open(outpath, "w") as f:
        f.write(pdb_str)
    outpath_str = str(outpath)

    # Per-atom B-factor column carries pLDDT in ESMFold output.
    p = Bio.PDB.PDBParser()
    structure = p.get_structure('myStructureName', outpath_str)
    pLDDTs = [a.get_bfactor() for a in structure.get_atoms()]
    avg_plddt = np.mean(pLDDTs)
    ptm = 0  # placeholder: the web API does not return iPTM

    elapsed = time.time() - start
    print(f'get_plddt time: {elapsed:.4f} seconds')
    return avg_plddt, ptm


def generate_peptide_for_single_sequence(model, tokenizer, protein_seq, peptide_length=15, top_k=3, num_binders=5, plddt_iptm_yes="no"):
    """Generate ``num_binders`` candidate peptides for one protein.

    Appends ``peptide_length`` mask tokens to the protein, samples each
    masked position from the top-k logits, and scores every candidate
    with pseudo-perplexity (and optionally pLDDT/iPTM).

    Returns:
        List of [binder, ppl, plddt, iptm] entries.
    """
    start = time.time()
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders_with_ppl_plddt = []
    for _ in range(num_binders):
        # BUGFIX: was "'' * peptide_length" (always empty) — must append one
        # mask token per peptide position to be generated.
        masked_peptide = tokenizer.mask_token * peptide_length
        input_sequence = protein_seq + masked_peptide
        inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
        print("198:model.device in generate_:", model.device)

        with torch.no_grad():
            logits = model(**inputs).logits
        mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        logits_at_masks = logits[0, mask_token_indices]

        # Top-k sampling per masked position.
        top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
        probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
        predicted_indices = Categorical(probabilities).sample()
        predicted_token_ids = top_k_indices.gather(-1, predicted_indices.unsqueeze(-1)).squeeze(-1)
        generated_binder = tokenizer.decode(predicted_token_ids, skip_special_tokens=True).replace(' ', '')

        # Score the generated binder.
        ppl = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, generated_binder)
            #plddt, iptm = compute_plddt_iptm(protein_seq, generated_binder)  # too time-consuming
        else:
            plddt, iptm = [0, 0]

        binders_with_ppl_plddt.append([generated_binder, ppl, plddt, iptm])

    elapsed = time.time() - start
    print(f'generate_peptide_for_single_sequence: {elapsed:.4f} seconds')
    return binders_with_ppl_plddt


def predict_peptide(base_model_path, finetuned_model_path, input_seqs, peptide_length=15, num_binders=4, plddt_iptm_yes="no"):
    """Predict peptide binders with a finetuned model.

    Returns:
        (results_df, PPC) where PPC is the input protein joined to the
        lowest-perplexity binder by a 20-glycine linker (for ESMFold
        complex plotting).
    """
    # Inference on CPU (no .to(device)).
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
    loaded_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    top_k = 3  # top-k mutations sampled at each AA position

    if isinstance(input_seqs, str):  # single sequence
        binders = generate_peptide_for_single_sequence(
            loaded_model, tokenizer, input_seqs, peptide_length, top_k, num_binders, plddt_iptm_yes)
        results_df = pd.DataFrame(binders, columns=['Binder', 'PPL', 'pLDDT', 'iPTM'])
    elif isinstance(input_seqs, list):  # list of sequences
        results = []
        for seq in input_seqs:
            binders = generate_peptide_for_single_sequence(
                loaded_model, tokenizer, seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
            for binder, ppl, plddt, iptm in binders:
                results.append([seq, binder, ppl, plddt, iptm])
        results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM'])
    print(results_df)

    # Combine target protein and predicted peptide with 20 G amino acids.
    # NOTE(review): for the list branch this concatenation (list + str)
    # would raise; the Gradio UI only ever passes a single string.
    separator = 'G' * 20
    peptide_lp = results_df['Binder'][results_df['PPL'].idxmin()]  # lowest perplexity wins
    PPC = input_seqs + separator + peptide_lp

    return results_df, PPC


def predict_peptide_from_file(base_model_path, finetuned_model_path, file_obj, max_peptide_length=15, num_binders=5, plddt_iptm_yes="no"):
    """Predict binders for each protein in a CSV and write results to disk.

    The CSV must have columns ``P_Sequence`` (protein) and ``p_Sequence``
    (ground-truth peptide). Each output row is flagged GT_Flag=1 for the
    ground-truth peptide, 0 for generated ones.

    Returns:
        Path (string) of the written results CSV.
    """
    start = time.time()
    loaded_model = AutoModelForMaskedLM.from_pretrained(finetuned_model_path)
    loaded_model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    input_df = pd.read_csv(file_obj, header=0)  # renamed: don't shadow builtin `input`

    results = []
    for i, row in input_df.iterrows():
        print("sequence:", i)
        protein_seq = row['P_Sequence']
        peptide_seq = row['p_Sequence']
        # Use the ground-truth peptide length, capped at max_peptide_length.
        peptide_length = min([len(peptide_seq), max_peptide_length])

        # Metrics for the ground-truth peptide.
        ppl = compute_pseudo_perplexity(loaded_model, tokenizer, protein_seq, peptide_seq)
        if plddt_iptm_yes == "yes":
            plddt, iptm = get_plddt(protein_seq, peptide_seq)
            #plddt, iptm = compute_plddt_iptm(protein_seq, peptide_seq)  # too time-consuming
        else:
            plddt, iptm = [0, 0]
        results.append([protein_seq, peptide_seq, ppl, plddt, iptm, 1])  # flag 1: ground truth

        # Generate and score predicted peptides.
        top_k = 3  # top-k mutations sampled at each AA position
        binders = generate_peptide_for_single_sequence(
            loaded_model, tokenizer, protein_seq, peptide_length, top_k, num_binders, plddt_iptm_yes)
        for binder, ppl, plddt, iptm in binders:
            results.append([protein_seq, binder, ppl, plddt, iptm, 0])  # flag 0: generated

    results_df = pd.DataFrame(results, columns=['Input Sequence', 'Binder', 'PPL', 'pLDDT', 'iPTM', 'GT_Flag'])

    timestamp = datetime.now().strftime('%Y-%m-%d_%H')
    outpath = Path.cwd() / f"predicted_peptides_{timestamp}.csv"
    # Work-around: the latest gr.File needs a string-typed path.
    str_outpath = str(outpath)
    results_df.to_csv(str_outpath, header=True, index=False)

    elapsed = time.time() - start
    print(f'predict_peptide_from_file: {elapsed:.4f} seconds')
    return str_outpath


def suggest(option):
    """Return a sample protein sequence for the chosen dropdown option."""
    if option == "Protein:P63279":
        suggestion = "MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS"
    elif option == "Default protein":
        suggestion = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"
    elif option == "Antifreeze protein":
        suggestion = "QCTGGADCTSCTGACTGCGNCPNAVTCTNSQHCVKANTCTGSTDCNTAQTCTNSKDCFEANTCTDSTNCYKATACTNSSGCPGH"
    elif option == "AI Generated protein":
        suggestion = "MSGMKKLYEYTVTTLDEFLEKLKEFILNTSKDKIYKLTITNPKLIKDIGKAIAKAAEIADVDPKEIEEMIKAVEENELTKLVITIEQTDDKYVIKVELENEDGLVHSFEIYFKNKEEMEKFLELLEKLISKLSGS"
    elif option == "7-bladed propeller fold":
        suggestion = "VKLAGNSSLCPINGWAVYSKDNSIRIGSKGDVFVIREPFISCSHLECRTFFLTQGALLNDKHSNGTVKDRSPHRTLMSCPVGEAPSPYNSRFESVAWSASACHDGTSWLTIGISGPDNGAVAVLKYNGIITDTIKSWRNNILRTQESECACVNGSCFTVMTDGPSNGQASYKIFKMEKGKVVKSVELDAPNYHYEECSCYPNAGEITCVCRDNWHGSNRPWVSFNQNLEYQIGYICSGVFGDNPRPNDGTGSCGPVSSNGAYGVKGFSFKYGNGVWIGRTKSTNSRSGFEMIWDPNGWTETDSSFSVKQDIVAITDWSGYSGSFVQHPELTGLDCIRPCFWVELIRGRPKESTIWTSGSSISFCGVNSDTVGWSWPDGAELPFTIDK"
    else:
        suggestion = ""
    return suggestion


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="ESM2 for Micropertide Protein Interaction (ESM2MPI)")
with demo:
    gr.Markdown("# ESM2 for Micropertide Protein Interaction (ESM2MPI)")
    with gr.Column():
        gr.Markdown("## Select a base model and a corresponding finetuned model")
        with gr.Row():
            with gr.Column(scale=5, variant="compact"):
                base_model_name = gr.Dropdown(
                    choices=MODEL_OPTIONS,
                    value=MODEL_OPTIONS[0],
                    label="Base Model Name",
                    interactive=True,
                )
                PEFT_model_name = gr.Dropdown(
                    choices=PEFT_MODEL_OPTIONS,
                    value=PEFT_MODEL_OPTIONS[0],
                    label="Finetuned Model Name",
                    interactive=True,
                )
                with gr.Row():
                    peptide_length = gr.Slider(minimum=10, maximum=100, step=1, label="Peptide Maximum Length", value=15)
                    num_pred_peptides = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Predicted Peptides", value=5)
                    plddt_iptm_yes = gr.Radio(["yes", "no"], label="Compute pLDDT and iPTM (slow!)", value="no")
            with gr.Column(scale=5, variant="compact"):
                name = gr.Dropdown(
                    label="Choose a Sample Protein",
                    value="Protein:P63279",
                    choices=["Default protein", "Antifreeze protein", "Protein:P63279", "AI Generated protein", "7-bladed propeller fold", "custom"]
                )
                uploaded_file = gr.File(
                    label="Local File Upload",
                    file_count="single",
                    file_types=[".tsv", ".csv"],
                    type="filepath",
                    height=40,
                )
        gr.Markdown("## Predict peptide sequence:")
        with gr.Row():
            with gr.Column(variant="compact", scale=8):
                input_seq = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Protein:P63279 to be predicted:",
                    value="MSGIALSRLAQERKAWRKDHPFGFVAVPTKNPDGTMNLMNWECAIPGKKGTPWEGGLFKLRMLFKDDYPSSPPKCKFEPPLFHPNVYPSGTVCLSILEEDKDWRPAITIKQILLGIQELLNEPNIQDPAQAEAYTIYCQNRVEYEKRVRAQAKKFAPS",
                    placeholder="Paste your protein sequence here...",
                    interactive=True,
                )
                text_pos = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Sequency Position:",
                    placeholder="012345678911234567892123456789312345678941234567895123456789612345678971234567898123456789912345678901234567891123456789",
                    interactive=False,
                )
            with gr.Column(variant="compact", scale=2):
                predict_btn = gr.Button(
                    value="Predict peptide sequence from a protein sequence",
                    interactive=True,
                    variant="primary",
                )
                plot_struc_btn = gr.Button(value="Plot ESMFold predicted structure ", variant="primary")
                predict_file_btn = gr.Button(
                    value="Predict peptide from a local file",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            with gr.Column(variant="compact", scale=5):
                output_text = gr.Textbox(
                    lines=1,
                    max_lines=12,
                    label="Output",
                    placeholder="Output",
                )
            with gr.Column(variant="compact", scale=5):
                finetune_button = gr.Button(
                    value="Finetune Pre-trained Model",
                    interactive=True,
                    variant="primary",
                )
        with gr.Row():
            output_viewer = gr.HTML()
            output_file = gr.File(
                label="Download as Text File",
                file_count="single",
                type="filepath",
                interactive=False,
            )

    # Select a protein sample.
    name.change(fn=suggest, inputs=name, outputs=input_seq)

    # "Predict peptide sequence" actions.
    predict_btn.click(
        fn=predict_peptide,
        inputs=[base_model_name, PEFT_model_name, input_seq, peptide_length, num_pred_peptides, plddt_iptm_yes],
        outputs=[output_text, input_seq],
    )

    # "Predict peptide from a local file" actions.
    predict_file_btn.click(
        fn=predict_peptide_from_file,
        inputs=[base_model_name, PEFT_model_name, uploaded_file, peptide_length, num_pred_peptides, plddt_iptm_yes],
        outputs=[output_file],
    )

    # "Finetune Pre-trained Model" actions.
    finetune_button.click(
        fn=finetune,
        inputs=[base_model_name, peptide_length],
        outputs=[output_text],
    )

    # Plot protein structure.
    plot_struc_btn.click(fn=plot_struc, inputs=input_seq, outputs=[output_file, output_viewer])

demo.launch()