Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import joblib | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel, EsmTokenizer, EsmModel | |
| from rdkit import Chem | |
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Pretrained encoders -----------------------------------------------------

# ESM-2 checkpoint used to generate protein embeddings.
esm = "facebook/esm2_t12_35M_UR50D"
esm_tokenizer = EsmTokenizer.from_pretrained(esm)
esm_model = EsmModel.from_pretrained(esm)
esm_model = esm_model.to(device)

# ChemBERTa checkpoint used to generate ligand embeddings.
chemberta = "DeepChem/ChemBERTa-10M-MTR"
chemberta_tokenizer = AutoTokenizer.from_pretrained(chemberta)
chemberta_model = AutoModel.from_pretrained(chemberta)
chemberta_model = chemberta_model.to(device)

# --- Fitted regression pipeline ----------------------------------------------

# The scaler, PCA and SVR objects are loaded (in that order) from pickle files
# expected in the current working directory.
scaler, pca, svr = (
    joblib.load(artifact)
    for artifact in ("scaler.pkl", "pca.pkl", "svr_model.pkl")
)
def generate_protein_embedding(protein_input, input_type):
    """Embed a protein sequence with the ESM model.

    The sequence is tokenized (truncated to 1024 tokens), passed through
    ``esm_model`` without gradient tracking, and the last hidden state is
    averaged over the token dimension.

    Args:
        protein_input: Path to a PDB file when ``input_type`` is ``"File"``;
            otherwise sequence text. The text may be a bare sequence or
            FASTA-formatted (``>`` header lines and wrapped sequence lines).
        input_type: ``"File"`` to read a PDB file; any other value treats
            ``protein_input`` as text.

    Returns:
        The mean-pooled embedding as a NumPy array, or ``None`` when the file
        cannot be parsed or no sequence can be extracted from the input.
    """
    if input_type == "File":
        mol = Chem.MolFromPDBFile(protein_input)
        if not mol:
            return None
        # The sequence is expected on the second line (after the header).
        # Guard the index so a structure that yields no sequence line is
        # reported as unparsable instead of raising IndexError.
        fasta_lines = Chem.MolToFASTA(mol).splitlines()
        if len(fasta_lines) < 2:
            return None
        fasta = fasta_lines[1]
    else:
        # Pasted FASTA may carry ">" header lines and line-wrapped residues.
        # Drop the headers and all whitespace so only residue letters reach
        # the tokenizer. Multiple records are concatenated in order.
        fasta = "".join(
            "".join(line.split())
            for line in (protein_input or "").splitlines()
            if not line.lstrip().startswith(">")
        )

    # An empty sequence would still produce an (meaningless) embedding from
    # the special tokens alone; treat it as invalid input instead.
    if not fasta:
        return None

    # Generate protein embedding
    esm_model.eval()
    with torch.no_grad():
        inputs = esm_tokenizer(
            fasta,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)
        outputs = esm_model(**inputs)
        # Mean-pool over the token dimension and move to host memory.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding
def generate_ligand_embedding(ligand_input, input_type):
    """Embed a ligand with the ChemBERTa model.

    The ligand is represented as a SMILES string, tokenized, passed through
    ``chemberta_model`` without gradient tracking, and the last hidden state
    is averaged over the token dimension.

    Args:
        ligand_input: Path to a .mol2 file when ``input_type`` is ``"File"``;
            otherwise a SMILES string.
        input_type: ``"File"`` to read a .mol2 file; any other value treats
            ``ligand_input`` as SMILES text.

    Returns:
        The mean-pooled embedding as a NumPy array, or ``None`` when the file
        cannot be parsed or the SMILES text is empty.
    """
    if input_type == "File":
        mol = Chem.MolFromMol2File(ligand_input)
        if not mol:
            return None
        smiles = Chem.MolToSmiles(mol)
    else:
        smiles = (ligand_input or "").strip()

    # An empty string would still be embedded (special tokens only) and give
    # a meaningless prediction downstream; report it as invalid instead.
    if not smiles:
        return None

    # Generate compound embedding from SMILES
    chemberta_model.eval()
    with torch.no_grad():
        inputs = chemberta_tokenizer(
            smiles,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        outputs = chemberta_model(**inputs)
        # Mean-pool over the token dimension and move to host memory.
        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embedding
def value_conversion(logKa):
    """Turn a predicted logKa (= -logKd) into a formatted Kd string.

    Kd is computed as ``10 ** (-logKa)`` and rendered with four decimals in
    the largest of mM / µM / nM whose lower bound it reaches; anything below
    1e-9 is rendered in pM.
    """
    kd_value = 10 ** (-logKa)

    # (lower bound on Kd, multiplier into the display unit, unit label),
    # ordered from the largest unit to the smallest.
    unit_table = (
        (1e-3, 1e3, "mM"),  # Millimolar
        (1e-6, 1e6, "µM"),  # Micromolar
        (1e-9, 1e9, "nM"),  # Nanomolar
    )
    for lower_bound, multiplier, label in unit_table:
        if kd_value >= lower_bound:
            return f"{kd_value * multiplier:.4f} {label}"

    # Everything smaller falls through to picomolar.
    return f"{kd_value * 1e12:.4f} pM"
def _pick_input(file_value, text_value, selected_type, text_type):
    """Choose between an uploaded file and typed text for one input slot.

    The typed text wins when its type is the one selected and it is non-empty;
    otherwise an uploaded file is used if present, then non-empty text.

    Returns:
        ``(value, kind)`` where ``kind`` is ``"File"`` or ``text_type``, or
        ``(None, None)`` when neither input is usable.
    """
    text = text_value.strip() if isinstance(text_value, str) else ""
    # Honor the selection: a file uploaded earlier may still be held by the
    # (possibly hidden) file component and must not override typed text.
    if selected_type == text_type and text:
        return text, text_type
    if file_value is not None:
        return file_value, "File"
    if text:
        return text, text_type
    return None, None


def predict_affinity(protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type):
    """Predict the binding affinity for one protein/ligand pair.

    Args:
        protein_file: Uploaded protein .pdb file, or ``None``.
        protein_fasta: Typed protein sequence text, possibly empty or ``None``.
        protein_type: Selected protein input type (``"File"`` or ``"FASTA"``).
        ligand_file: Uploaded ligand .mol2 file, or ``None``.
        ligand_smiles: Typed ligand SMILES text, possibly empty or ``None``.
        ligand_type: Selected ligand input type (``"File"`` or ``"SMILES"``).

    Returns:
        A human-readable string: either the predicted logKa and Kd, or an
        error message when an input is missing or cannot be parsed.
    """
    # Determine protein input
    protein_input, protein_type = _pick_input(protein_file, protein_fasta, protein_type, "FASTA")
    if protein_input is None:
        return "Error: No valid protein input provided."

    # Determine ligand input
    ligand_input, ligand_type = _pick_input(ligand_file, ligand_smiles, ligand_type, "SMILES")
    if ligand_input is None:
        return "Error: No valid ligand input provided."

    # Get embeddings; stop at the first failure so the ligand model is not
    # run for a protein that could not be parsed.
    protein = generate_protein_embedding(protein_input, protein_type)
    if protein is None:
        if protein_type == "File":
            return "Unable to parse protein .pdb file"
        return "Unable to parse protein FASTA sequence"

    ligand = generate_ligand_embedding(ligand_input, ligand_type)
    if ligand is None:
        if ligand_type == "File":
            return "Unable to parse ligand .mol2 file"
        return "Unable to parse ligand SMILES string"

    # Join the two embeddings feature-wise into one input row.
    embedding = np.concatenate((protein, ligand), axis=1)

    # Apply scaling and PCA
    svr_input = scaler.transform(embedding)
    svr_input = pca.transform(svr_input)

    # Predict the log binding affinity
    log_prediction = svr.predict(svr_input)[0]
    affinity_value = value_conversion(log_prediction)
    return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
def update_inputs(protein_type, ligand_type):
    """Build the visibility/interactivity updates for the four input widgets.

    Returns a 4-tuple of ``gr.update`` objects, in the order: protein file,
    protein FASTA box, ligand file, ligand SMILES box. Each widget is visible
    and interactive only when its own input type is the one selected.
    """
    # (currently selected type, type that this widget serves)
    widget_rules = (
        (protein_type, "File"),
        (protein_type, "FASTA"),
        (ligand_type, "File"),
        (ligand_type, "SMILES"),
    )
    return tuple(
        gr.update(visible=(selected == served), interactive=(selected == served))
        for selected, served in widget_rules
    )
# Build the Gradio interface.
# NOTE(review): the original indentation was lost in this copy of the file;
# the nesting below assumes only the two type selectors sit inside the row and
# that everything up to the event wiring belongs to the Blocks context —
# confirm against the deployed layout.
with gr.Blocks() as iface:
    gr.Markdown("# Predict Protein-Ligand Binding Affinity")
    gr.Markdown("Upload protein and compound files or enter FASTA/SMILES strings to predict binding affinity.")

    # Input-type selectors; both default to file upload.
    with gr.Row():
        protein_type = gr.Radio(["File", "FASTA"], label="Protein Input Type", value="File")
        ligand_type = gr.Radio(["File", "SMILES"], label="Ligand Input Type", value="File")

    # The text boxes start hidden because "File" is the default selection.
    protein_file = gr.File(label="Protein .pdb file", visible=True, interactive=True)
    protein_fasta = gr.Textbox(label="Protein FASTA sequence", visible=False, interactive=True)
    ligand_file = gr.File(label="Ligand .mol2 file", visible=True, interactive=True)
    ligand_smiles = gr.Textbox(label="Ligand SMILES string", visible=False, interactive=True)

    output = gr.Textbox(label="Prediction Result", lines=3)
    submit_btn = gr.Button("Predict")
    submit_btn.click(
        predict_affinity,
        inputs=[protein_file, protein_fasta, protein_type, ligand_file, ligand_smiles, ligand_type],
        outputs=output,
    )

    # Changing either selector re-evaluates the visibility of all four inputs.
    protein_type.change(
        update_inputs,
        inputs=[protein_type, ligand_type],
        outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles],
    )
    ligand_type.change(
        update_inputs,
        inputs=[protein_type, ligand_type],
        outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles],
    )

# Start the server only when run as a script, so importing this module (for
# tests or tooling) does not launch it.
if __name__ == "__main__":
    iface.launch()