"""Gradio app that predicts protein-ligand binding affinity.

Proteins are embedded with ESM-2 and ligands with ChemBERTa (mean-pooled last
hidden states). The concatenated embedding goes through a pre-fitted scaler,
PCA and SVR that predict logKa (= -logKd), which is reported together with the
corresponding Kd.
"""
import os

import gradio as gr
import joblib
import numpy as np
import torch
from rdkit import Chem
from transformers import AutoModel, AutoTokenizer, EsmModel, EsmTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained models once at start-up. They are only used for
# inference, so eval mode is set here instead of on every request.
esm = "facebook/esm2_t12_35M_UR50D"  # generate protein embeddings
esm_tokenizer = EsmTokenizer.from_pretrained(esm)
esm_model = EsmModel.from_pretrained(esm).to(device).eval()

chemberta = "DeepChem/ChemBERTa-10M-MTR"  # generate ligand embeddings
chemberta_tokenizer = AutoTokenizer.from_pretrained(chemberta)
chemberta_model = AutoModel.from_pretrained(chemberta).to(device).eval()

# Pre-fitted regression pipeline. joblib files are pickles: only load trusted
# artifacts shipped with the app, never user-supplied files.
scaler = joblib.load("scaler.pkl")
pca = joblib.load("pca.pkl")
svr = joblib.load("svr_model.pkl")

# (lower bound on Kd in molar, multiplier, unit label), checked in order;
# anything below the last bound is reported in picomolar.
_KD_UNITS = (
    (1e-3, 1e3, "mM"),  # Millimolar
    (1e-6, 1e6, "µM"),  # Micromolar
    (1e-9, 1e9, "nM"),  # Nanomolar
)


def _file_path(file_value):
    """Return a filesystem path for a Gradio File value.

    Depending on the Gradio version / ``type`` setting, a File component yields
    either a path (str / PathLike) or a temp-file object exposing ``.name``.
    RDKit's file readers need a plain path string.
    """
    if isinstance(file_value, (str, os.PathLike)):
        return os.fspath(file_value)
    return file_value.name


def _clean_sequence(text):
    """Return the residues of a FASTA record or raw sequence as one uppercase string.

    Header lines (starting with ``>``) are dropped, and so is all whitespace,
    including the line breaks of wrapped sequences. Only the residue letters
    reach the ESM tokenizer.
    """
    residue_lines = [
        line for line in text.splitlines() if not line.lstrip().startswith(">")
    ]
    return "".join("".join(residue_lines).split()).upper()


def generate_protein_embedding(protein_input, input_type):
    """Embed a protein with ESM-2 by mean-pooling its last hidden state.

    Args:
        protein_input: a .pdb file (path or Gradio file value) when
            ``input_type == "File"``; otherwise the sequence text, either raw or
            a FASTA record.
        input_type: ``"File"`` for a structure file, anything else for text.

    Returns:
        A NumPy array with one row (the single input sequence), or ``None`` if
        the file could not be parsed or no residues were found.
    """
    if input_type == "File":
        mol = Chem.MolFromPDBFile(_file_path(protein_input))
        if mol is None:
            return None
        # MolToFASTA returns a FASTA record (header line + sequence).
        sequence = _clean_sequence(Chem.MolToFASTA(mol))
    else:
        sequence = _clean_sequence(protein_input)

    if not sequence:
        # e.g. a PDB without recognised residues, or a FASTA header only.
        return None

    with torch.no_grad():
        inputs = esm_tokenizer(
            sequence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)
        outputs = esm_model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding


def generate_ligand_embedding(ligand_input, input_type):
    """Embed a ligand with ChemBERTa by mean-pooling its last hidden state.

    Args:
        ligand_input: a .mol2 file (path or Gradio file value) when
            ``input_type == "File"``; otherwise a SMILES string.
        input_type: ``"File"`` for a structure file, anything else for SMILES.

    Returns:
        A NumPy array with one row, or ``None`` if the file could not be parsed
        or the SMILES string is empty.
    """
    if input_type == "File":
        mol = Chem.MolFromMol2File(_file_path(ligand_input))
        if mol is None:
            return None
        smiles = Chem.MolToSmiles(mol)
    else:
        smiles = ligand_input.strip()

    if not smiles:
        return None

    # Generate compound embeddings from SMILES
    with torch.no_grad():
        inputs = chemberta_tokenizer(
            smiles, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        outputs = chemberta_model(**inputs)
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding


def value_conversion(logKa):
    """Convert a predicted logKa (= -logKd) into a human-readable Kd string.

    Kd = 10 ** (-logKa) in molar. It is shown in the largest of mM / µM / nM
    whose lower bound it reaches, and otherwise in pM, with 4 decimals.
    """
    Kd = 10 ** (logKa * -1)
    for lower_bound, factor, unit in _KD_UNITS:
        if Kd >= lower_bound:
            return f"{Kd * factor:.4f} {unit}"
    return f"{Kd * 1e12:.4f} pM"  # Picomolar


def _pick_input(selected_type, text_type, file_value, text_value):
    """Choose which of a file / text input pair to use.

    The component matching the radio selection takes priority. A stale value
    left in the hidden component therefore cannot override what the user
    selected. The other component is only a fallback when the selected one is
    empty.

    Args:
        selected_type: the radio value (``"File"`` or ``text_type``).
        text_type: the label of the text option (``"FASTA"`` / ``"SMILES"``).
        file_value: the File component value (``None`` when nothing uploaded).
        text_value: the Textbox value (``""`` or ``None`` when empty).

    Returns:
        ``(value, input_type)`` where ``input_type`` is ``"File"`` or
        ``text_type``, or ``(None, None)`` when both inputs are empty.
    """
    has_file = file_value is not None
    text = (text_value or "").strip()
    prefer_text = selected_type == text_type
    if text and (prefer_text or not has_file):
        return text, text_type
    if has_file:
        return file_value, "File"
    return None, None


def predict_affinity(protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type):
    """Predict the binding affinity of a protein-ligand pair for the UI.

    Returns:
        A multi-line result string with logKa and Kd, or an error message when
        an input is missing or cannot be parsed.
    """
    protein_input, protein_type = _pick_input(protein_type, "FASTA", protein_file, protein_fasta)
    if protein_input is None:
        return "Error: No valid protein input provided."

    ligand_input, ligand_type = _pick_input(ligand_type, "SMILES", ligand_file, ligand_smiles)
    if ligand_input is None:
        return "Error: No valid ligand input provided."

    # Embed the protein first so a bad protein input skips the ligand model.
    protein = generate_protein_embedding(protein_input, protein_type)
    if protein is None:
        if protein_type == "File":
            return "Unable to parse protein .pdb file"
        return "Unable to read a protein sequence from the FASTA input"

    ligand = generate_ligand_embedding(ligand_input, ligand_type)
    if ligand is None:
        if ligand_type == "File":
            return "Unable to parse ligand .mol2 file"
        return "Unable to read the ligand SMILES string"

    # Same protein-then-ligand feature order as the fitted scaler/PCA expect.
    embedding = np.concatenate((protein, ligand), axis=1)

    # Apply scaling and PCA
    svr_input = scaler.transform(embedding)
    svr_input = pca.transform(svr_input)

    # Predict the log binding affinity
    log_prediction = svr.predict(svr_input)[0]
    affinity_value = value_conversion(log_prediction)

    return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"


def update_inputs(protein_type, ligand_type):
    """Show/enable only the input components matching the radio selections.

    Returns updates for, in order: protein file, protein FASTA, ligand file,
    ligand SMILES.
    """
    return (
        gr.update(visible=(protein_type == "File"), interactive=(protein_type == "File")),
        gr.update(visible=(protein_type == "FASTA"), interactive=(protein_type == "FASTA")),
        gr.update(visible=(ligand_type == "File"), interactive=(ligand_type == "File")),
        gr.update(visible=(ligand_type == "SMILES"), interactive=(ligand_type == "SMILES")),
    )


with gr.Blocks() as iface:
    gr.Markdown("# Predict Protein-Ligand Binding Affinity")
    gr.Markdown("Upload protein and compound files or enter FASTA/SMILES strings to predict binding affinity.")

    with gr.Row():
        protein_type = gr.Radio(["File", "FASTA"], label="Protein Input Type", value="File")
        ligand_type = gr.Radio(["File", "SMILES"], label="Ligand Input Type", value="File")

    protein_file = gr.File(label="Protein .pdb file", visible=True, interactive=True)
    protein_fasta = gr.Textbox(label="Protein FASTA sequence", visible=False, interactive=True)
    ligand_file = gr.File(label="Ligand .mol2 file", visible=True, interactive=True)
    ligand_smiles = gr.Textbox(label="Ligand SMILES string", visible=False, interactive=True)

    output = gr.Textbox(label="Prediction Result", lines=3)
    submit_btn = gr.Button("Predict")

    submit_btn.click(
        predict_affinity,
        inputs=[protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type],
        outputs=output,
    )

    # Either radio changing re-evaluates the visibility of all four inputs.
    protein_type.change(
        update_inputs,
        inputs=[protein_type, ligand_type],
        outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles],
    )
    ligand_type.change(
        update_inputs,
        inputs=[protein_type, ligand_type],
        outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles],
    )

if __name__ == "__main__":
    iface.launch()