# nninva's picture
# Fixed interface with option for files or strings inputs
# 38fba51 verified
import gradio as gr
import joblib
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, EsmTokenizer, EsmModel
from rdkit import Chem
# Inference device: CUDA when torch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face checkpoints for the two encoders.
esm = "facebook/esm2_t12_35M_UR50D"  # protein-sequence encoder (ESM-2)
chemberta = "DeepChem/ChemBERTa-10M-MTR"  # ligand-SMILES encoder (ChemBERTa)

# Load the pretrained tokenizers and models, moving each model onto `device`.
esm_tokenizer = EsmTokenizer.from_pretrained(esm)
esm_model = EsmModel.from_pretrained(esm).to(device)

chemberta_tokenizer = AutoTokenizer.from_pretrained(chemberta)
chemberta_model = AutoModel.from_pretrained(chemberta).to(device)

# Fitted scaler -> PCA -> SVR stages applied to the concatenated protein+ligand
# embedding in predict_affinity.
# NOTE(review): joblib.load unpickles arbitrary objects; only load trusted .pkl files.
scaler, pca, svr = (
    joblib.load(artifact) for artifact in ("scaler.pkl", "pca.pkl", "svr_model.pkl")
)
def generate_protein_embedding(protein_input, input_type):
    """Embed a protein with ESM-2 by mean-pooling its last hidden state.

    Args:
        protein_input: When ``input_type == "File"``, the uploaded PDB file,
            given either as a path string or as an object exposing the path as
            ``.name`` (the form Gradio hands back varies by version). Otherwise
            FASTA text: a bare sequence, or a record with ``>`` header lines
            and/or a sequence wrapped over several lines.
        input_type: ``"File"`` to read a PDB file; any other value treats
            ``protein_input`` as FASTA text.

    Returns:
        A 2-D NumPy array with one row (the pooled embedding), or ``None`` if
        no amino-acid sequence could be obtained from the input.
    """
    if input_type == "File":
        # Accept both a plain path and a tempfile-like wrapper with `.name`.
        path = getattr(protein_input, "name", protein_input)
        mol = Chem.MolFromPDBFile(path)
        if not mol:
            return None
        # MolToFASTA output is ">title" followed by the sequence line. Guard the
        # index so an empty/short result returns None instead of raising IndexError.
        fasta_lines = Chem.MolToFASTA(mol).splitlines()
        if len(fasta_lines) < 2:
            return None
        fasta = fasta_lines[1].strip()
    else:
        fasta = _clean_fasta_text(protein_input)

    # Nothing to embed: report failure rather than embedding an empty string.
    if not fasta:
        return None

    esm_model.eval()
    with torch.no_grad():
        inputs = esm_tokenizer(
            fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024
        ).to(device)
        outputs = esm_model(**inputs)
        # Plain mean over every token position (no attention-mask weighting).
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding


def _clean_fasta_text(text):
    """Return the bare, upper-cased residue string from pasted FASTA text ('' if none)."""
    if not text:
        return ""
    residues = []
    for line in text.splitlines():
        line = line.strip()
        # Skip blank lines, '>' record headers and legacy ';' comment lines.
        if not line or line.startswith((">", ";")):
            continue
        # Wrapped FASTA splits one sequence over many lines; drop any inner spaces too.
        residues.append("".join(line.split()))
    # The ESM vocabulary uses upper-case residue letters.
    return "".join(residues).upper()
def generate_ligand_embedding(ligand_input, input_type):
    """Embed a ligand with ChemBERTa by mean-pooling its last hidden state.

    Args:
        ligand_input: When ``input_type == "File"``, the uploaded .mol2 file,
            given either as a path string or as an object exposing the path as
            ``.name`` (the form Gradio hands back varies by version). Otherwise
            a SMILES string.
        input_type: ``"File"`` to read a Mol2 file; any other value treats
            ``ligand_input`` as SMILES text.

    Returns:
        A 2-D NumPy array with one row (the pooled embedding), or ``None`` if
        the file or the SMILES string cannot be parsed by RDKit.
    """
    if input_type == "File":
        # Accept both a plain path and a tempfile-like wrapper with `.name`.
        path = getattr(ligand_input, "name", ligand_input)
        mol = Chem.MolFromMol2File(path)
        if not mol:
            return None
        smiles = Chem.MolToSmiles(mol)
    else:
        smiles = ligand_input.strip() if ligand_input else ""
        # Reject blank text up front. Then require RDKit to parse the SMILES, so
        # text input meets the same bar as file input instead of being embedded
        # as arbitrary characters.
        if not smiles or Chem.MolFromSmiles(smiles) is None:
            return None

    # Generate the compound embedding from SMILES.
    chemberta_model.eval()
    with torch.no_grad():
        inputs = chemberta_tokenizer(
            smiles, return_tensors="pt", padding=True, truncation=True
        ).to(device)
        outputs = chemberta_model(**inputs)
        # Plain mean over every token position (no attention-mask weighting).
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding
# Convert the predicted log10(Ka) (= -log10 Kd) into a readable Kd string.
def value_conversion(logKa):
    """Format a predicted log10(Ka) as a dissociation constant with a unit.

    Since Ka = 1 / Kd, the molar Kd is ``10 ** (-logKa)``. The value is shown
    in the first of mM, µM, nM whose lower bound it reaches; anything smaller
    is shown in pM. Four decimal places are always printed.
    """
    kd = 10 ** (-logKa)
    # (lower bound in molar, multiplier, unit label), from the largest unit down.
    scales = (
        (1e-3, 1e3, "mM"),  # millimolar
        (1e-6, 1e6, "µM"),  # micromolar
        (1e-9, 1e9, "nM"),  # nanomolar
    )
    for lower_bound, multiplier, unit in scales:
        if kd >= lower_bound:
            return f"{kd * multiplier:.4f} {unit}"
    return f"{kd * 1e12:.4f} pM"  # picomolar
def predict_affinity(protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type):
    """Predict binding affinity for one protein–ligand pair from the UI inputs.

    The radio selections (``protein_type`` / ``ligand_type``) decide which
    widget is read. The hidden widget keeps its old value when the user
    switches modes, so it must not override the visible one. If a selection
    is neither ``"File"`` nor the text mode (e.g. a direct API call), the
    file takes precedence over the text.

    Returns:
        A multi-line result string with logKa and the formatted Kd, or an
        error message when an input is missing or cannot be parsed.
    """
    protein_input, protein_type = _select_input(protein_type, protein_file, protein_fasta, "FASTA")
    if protein_input is None:
        return "Error: No valid protein input provided."

    ligand_input, ligand_type = _select_input(ligand_type, ligand_file, ligand_smiles, "SMILES")
    if ligand_input is None:
        return "Error: No valid ligand input provided."

    # Embed the protein first so a bad protein fails fast without running ChemBERTa.
    protein = generate_protein_embedding(protein_input, protein_type)
    if protein is None:
        if protein_type == "File":
            return "Unable to parse protein .pdb file"
        return "Unable to parse protein FASTA sequence"

    ligand = generate_ligand_embedding(ligand_input, ligand_type)
    if ligand is None:
        if ligand_type == "File":
            return "Unable to parse ligand .mol2 file"
        return "Unable to parse ligand SMILES string"

    # Protein features first, then ligand features, along the feature axis.
    embedding = np.concatenate((protein, ligand), axis=1)
    # Apply the fitted scaling and PCA before the regressor.
    svr_input = pca.transform(scaler.transform(embedding))
    # Predict the log binding affinity (log10 Ka) for the single row.
    log_prediction = svr.predict(svr_input)[0]
    affinity_value = value_conversion(log_prediction)
    return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"


def _select_input(selected_type, file_value, text_value, text_type):
    """Return ``(input, type)`` for the selected source, or ``(None, None)`` if it is empty."""
    # Gradio Textboxes yield "" (not None) when left blank, so test truthiness after strip.
    text = text_value.strip() if text_value else ""
    if selected_type == "File":
        return (file_value, "File") if file_value is not None else (None, None)
    if selected_type == text_type:
        return (text, text_type) if text else (None, None)
    # Unrecognised selection: file first, then text.
    if file_value is not None:
        return file_value, "File"
    if text:
        return text, text_type
    return None, None
def update_inputs(protein_type, ligand_type):
    """Show/enable only the widgets that match the selected input modes.

    Returns four ``gr.update`` objects, in this order: protein file, protein
    FASTA textbox, ligand file, ligand SMILES textbox. Each is visible and
    interactive exactly when its mode is the selected one.
    """
    flags = (
        protein_type == "File",
        protein_type == "FASTA",
        ligand_type == "File",
        ligand_type == "SMILES",
    )
    return tuple(gr.update(visible=shown, interactive=shown) for shown in flags)
with gr.Blocks() as iface:
    gr.Markdown("# Predict Protein-Ligand Binding Affinity")
    gr.Markdown("Upload protein and compound files or enter FASTA/SMILES strings to predict binding affinity.")

    # Input-mode selectors; changing either one re-runs update_inputs below.
    with gr.Row():
        protein_type = gr.Radio(["File", "FASTA"], label="Protein Input Type", value="File")
        ligand_type = gr.Radio(["File", "SMILES"], label="Ligand Input Type", value="File")

    # All four input widgets exist up front; only the "File" variants start visible.
    protein_file = gr.File(label="Protein .pdb file", visible=True, interactive=True)
    protein_fasta = gr.Textbox(label="Protein FASTA sequence", visible=False, interactive=True)
    ligand_file = gr.File(label="Ligand .mol2 file", visible=True, interactive=True)
    ligand_smiles = gr.Textbox(label="Ligand SMILES string", visible=False, interactive=True)

    output = gr.Textbox(label="Prediction Result", lines=3)
    submit_btn = gr.Button("Predict")

    submit_btn.click(
        predict_affinity,
        inputs=[protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type],
        outputs=output,
    )

    # Register the same visibility toggle on both radios (protein first, then ligand).
    for mode_selector in (protein_type, ligand_type):
        mode_selector.change(
            update_inputs,
            inputs=[protein_type, ligand_type],
            outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles],
        )

iface.launch()