Spaces:
Sleeping
Sleeping
Chaning the interface to allow use input FASTA/SMILES or files
Browse files
app.py
CHANGED
|
@@ -18,36 +18,66 @@ scaler = joblib.load("scaler.pkl")
|
|
| 18 |
pca = joblib.load("pca.pkl")
|
| 19 |
svr = joblib.load("svr_model.pkl")
|
| 20 |
|
| 21 |
-
def generate_protein_embedding(protein):
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
esm_model.eval()
|
| 31 |
with torch.no_grad():
|
| 32 |
inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
|
| 33 |
outputs = esm_model(**inputs)
|
| 34 |
-
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
| 35 |
return embedding
|
| 36 |
|
| 37 |
-
def generate_ligand_embedding(ligand):
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
chemberta_model.eval()
|
| 47 |
with torch.no_grad():
|
| 48 |
inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
|
| 49 |
outputs = chemberta_model(**inputs)
|
| 50 |
-
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
| 51 |
return embedding
|
| 52 |
|
| 53 |
def value_conversion(logKa):
|
|
@@ -63,9 +93,13 @@ def value_conversion(logKa):
|
|
| 63 |
else:
|
| 64 |
return f"{Kd * 1e12:.4f} pM" # Picomolar
|
| 65 |
|
| 66 |
-
def predict_affinity(protein_file, ligand_file):
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if protein is None:
|
| 70 |
return "Unable to parse protein .pdb file"
|
| 71 |
if ligand is None:
|
|
@@ -81,15 +115,49 @@ def predict_affinity(protein_file, ligand_file):
|
|
| 81 |
affinity_value = value_conversion(log_prediction)
|
| 82 |
return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
|
| 83 |
|
| 84 |
-
# Gradio interface
|
| 85 |
-
iface = gr.Interface(
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
# Run Gradio App
|
| 94 |
-
if __name__ == "__main__":
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
pca = joblib.load("pca.pkl")
|
| 19 |
svr = joblib.load("svr_model.pkl")
|
| 20 |
|
| 21 |
+
# def generate_protein_embedding(protein):
|
| 22 |
+
# # Generate FASTA string
|
| 23 |
+
# mol = Chem.MolFromPDBFile(protein)
|
| 24 |
+
# if not mol:
|
| 25 |
+
# print("Could not convert file to protein molecule")
|
| 26 |
+
# return None
|
| 27 |
+
# fasta = Chem.MolToFASTA(mol).splitlines()[1]
|
| 28 |
+
|
| 29 |
+
# # Generate protein embedding
|
| 30 |
+
# esm_model.eval()
|
| 31 |
+
# with torch.no_grad():
|
| 32 |
+
# inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
|
| 33 |
+
# outputs = esm_model(**inputs)
|
| 34 |
+
# embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy() # Extract last layer and mean pooling
|
| 35 |
+
# return embedding
|
| 36 |
+
def generate_protein_embedding(protein_input, input_type):
|
| 37 |
+
if input_type == "File":
|
| 38 |
+
mol = Chem.MolFromPDBFile(protein_input)
|
| 39 |
+
if not mol:
|
| 40 |
+
return None
|
| 41 |
+
fasta = Chem.MolToFASTA(mol).splitlines()[1]
|
| 42 |
+
else:
|
| 43 |
+
fasta = protein_input.strip()
|
| 44 |
+
|
| 45 |
esm_model.eval()
|
| 46 |
with torch.no_grad():
|
| 47 |
inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
|
| 48 |
outputs = esm_model(**inputs)
|
| 49 |
+
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
| 50 |
return embedding
|
| 51 |
|
| 52 |
+
# def generate_ligand_embedding(ligand):
|
| 53 |
+
# # Generate SMILES string
|
| 54 |
+
# mol = Chem.MolFromMol2File(ligand)
|
| 55 |
+
# if not mol:
|
| 56 |
+
# print("Could not convert file to ligand molecule")
|
| 57 |
+
# return None
|
| 58 |
+
# smiles = Chem.MolToSmiles(mol)
|
| 59 |
|
| 60 |
+
# # Generate ligand embedding
|
| 61 |
+
# chemberta_model.eval()
|
| 62 |
+
# with torch.no_grad():
|
| 63 |
+
# inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
|
| 64 |
+
# outputs = chemberta_model(**inputs)
|
| 65 |
+
# embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
| 66 |
+
# return embedding
|
| 67 |
+
def generate_ligand_embedding(ligand_input, input_type):
|
| 68 |
+
if input_type == "File":
|
| 69 |
+
mol = Chem.MolFromMol2File(ligand_input)
|
| 70 |
+
if not mol:
|
| 71 |
+
return None
|
| 72 |
+
smiles = Chem.MolToSmiles(mol)
|
| 73 |
+
else:
|
| 74 |
+
smiles = ligand_input.strip()
|
| 75 |
+
|
| 76 |
chemberta_model.eval()
|
| 77 |
with torch.no_grad():
|
| 78 |
inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
|
| 79 |
outputs = chemberta_model(**inputs)
|
| 80 |
+
embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
|
| 81 |
return embedding
|
| 82 |
|
| 83 |
def value_conversion(logKa):
|
|
|
|
| 93 |
else:
|
| 94 |
return f"{Kd * 1e12:.4f} pM" # Picomolar
|
| 95 |
|
| 96 |
+
# def predict_affinity(protein_file, ligand_file):
|
| 97 |
+
# protein = generate_protein_embedding(protein_file)
|
| 98 |
+
# ligand = generate_ligand_embedding(ligand_file)
|
| 99 |
+
def predict_affinity(protein_input, protein_type, ligand_input, ligand_type):
|
| 100 |
+
protein = generate_protein_embedding(protein_input, protein_type)
|
| 101 |
+
ligand = generate_ligand_embedding(ligand_input, ligand_type)
|
| 102 |
+
|
| 103 |
if protein is None:
|
| 104 |
return "Unable to parse protein .pdb file"
|
| 105 |
if ligand is None:
|
|
|
|
| 115 |
affinity_value = value_conversion(log_prediction)
|
| 116 |
return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
|
| 117 |
|
| 118 |
+
# # Gradio interface
|
| 119 |
+
# iface = gr.Interface(
|
| 120 |
+
# fn=predict_affinity,
|
| 121 |
+
# inputs=[gr.File(label="Protein .pdb file"), gr.File(label="Ligand .mol2 file")],
|
| 122 |
+
# outputs="text",
|
| 123 |
+
# title="Predict Protein-Ligand Binding Affinity",
|
| 124 |
+
# description="Upload the protein and ligand files to predict the binding affinity of the protein-ligand complex.",
|
| 125 |
+
# )
|
| 126 |
+
|
| 127 |
+
# # Run Gradio App
|
| 128 |
+
# if __name__ == "__main__":
|
| 129 |
+
# iface.launch()
|
| 130 |
+
|
| 131 |
+
def update_inputs(protein_type, ligand_type):
|
| 132 |
+
return (
|
| 133 |
+
gr.update(visible=(protein_type == "File")), # Protein file input
|
| 134 |
+
gr.update(visible=(protein_type == "FASTA")), # Protein FASTA input
|
| 135 |
+
gr.update(visible=(ligand_type == "File")), # Ligand file input
|
| 136 |
+
gr.update(visible=(ligand_type == "SMILES")) # Ligand SMILES input
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
with gr.Blocks() as iface:
|
| 140 |
+
gr.Markdown("# Predict Protein-Ligand Binding Affinity")
|
| 141 |
+
gr.Markdown("Upload files or enter FASTA/SMILES strings to predict binding affinity.")
|
| 142 |
+
|
| 143 |
+
protein_type = gr.Radio(["File", "FASTA"], label="Protein Input Type", value="File")
|
| 144 |
+
protein_file = gr.File(label="Protein .pdb file", visible=True)
|
| 145 |
+
protein_fasta = gr.Textbox(label="Protein FASTA sequence", visible=False)
|
| 146 |
+
|
| 147 |
+
ligand_type = gr.Radio(["File", "SMILES"], label="Ligand Input Type", value="File")
|
| 148 |
+
ligand_file = gr.File(label="Ligand .mol2 file", visible=True)
|
| 149 |
+
ligand_smiles = gr.Textbox(label="Ligand SMILES string", visible=False)
|
| 150 |
+
|
| 151 |
+
output = gr.Textbox(label="Prediction Result", lines=3)
|
| 152 |
+
|
| 153 |
+
submit_btn = gr.Button("Predict")
|
| 154 |
+
submit_btn.click(
|
| 155 |
+
predict_affinity,
|
| 156 |
+
inputs=[protein_file, protein_type, ligand_file, ligand_type],
|
| 157 |
+
outputs=output
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
protein_type.change(update_inputs, inputs=[protein_type, ligand_type], outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles])
|
| 161 |
+
ligand_type.change(update_inputs, inputs=[protein_type, ligand_type], outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles])
|
| 162 |
+
|
| 163 |
+
iface.launch()
|