nninva commited on
Commit
d96e3d2
·
verified ·
1 Parent(s): 0f8a299

Chaning the interface to allow use input FASTA/SMILES or files

Browse files
Files changed (1) hide show
  1. app.py +102 -34
app.py CHANGED
@@ -18,36 +18,66 @@ scaler = joblib.load("scaler.pkl")
18
  pca = joblib.load("pca.pkl")
19
  svr = joblib.load("svr_model.pkl")
20
 
21
- def generate_protein_embedding(protein):
22
- # Generate FASTA string
23
- mol = Chem.MolFromPDBFile(protein)
24
- if not mol:
25
- print("Could not convert file to protein molecule")
26
- return None
27
- fasta = Chem.MolToFASTA(mol).splitlines()[1]
28
-
29
- # Generate protein embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  esm_model.eval()
31
  with torch.no_grad():
32
  inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
33
  outputs = esm_model(**inputs)
34
- embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy() # Extract last layer and mean pooling
35
  return embedding
36
 
37
- def generate_ligand_embedding(ligand):
38
- # Generate SMILES string
39
- mol = Chem.MolFromMol2File(ligand)
40
- if not mol:
41
- print("Could not convert file to ligand molecule")
42
- return None
43
- smiles = Chem.MolToSmiles(mol)
44
 
45
- # Generate ligand embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  chemberta_model.eval()
47
  with torch.no_grad():
48
  inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
49
  outputs = chemberta_model(**inputs)
50
- embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
51
  return embedding
52
 
53
  def value_conversion(logKa):
@@ -63,9 +93,13 @@ def value_conversion(logKa):
63
  else:
64
  return f"{Kd * 1e12:.4f} pM" # Picomolar
65
 
66
- def predict_affinity(protein_file, ligand_file):
67
- protein = generate_protein_embedding(protein_file)
68
- ligand = generate_ligand_embedding(ligand_file)
 
 
 
 
69
  if protein is None:
70
  return "Unable to parse protein .pdb file"
71
  if ligand is None:
@@ -81,15 +115,49 @@ def predict_affinity(protein_file, ligand_file):
81
  affinity_value = value_conversion(log_prediction)
82
  return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
83
 
84
- # Gradio interface
85
- iface = gr.Interface(
86
- fn=predict_affinity,
87
- inputs=[gr.File(label="Protein .pdb file"), gr.File(label="Ligand .mol2 file")],
88
- outputs="text",
89
- title="Predict Protein-Ligand Binding Affinity",
90
- description="Upload the protein and ligand files to predict the binding affinity of the protein-ligand complex.",
91
- )
92
-
93
- # Run Gradio App
94
- if __name__ == "__main__":
95
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  pca = joblib.load("pca.pkl")
19
  svr = joblib.load("svr_model.pkl")
20
 
21
+ # def generate_protein_embedding(protein):
22
+ # # Generate FASTA string
23
+ # mol = Chem.MolFromPDBFile(protein)
24
+ # if not mol:
25
+ # print("Could not convert file to protein molecule")
26
+ # return None
27
+ # fasta = Chem.MolToFASTA(mol).splitlines()[1]
28
+
29
+ # # Generate protein embedding
30
+ # esm_model.eval()
31
+ # with torch.no_grad():
32
+ # inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
33
+ # outputs = esm_model(**inputs)
34
+ # embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy() # Extract last layer and mean pooling
35
+ # return embedding
36
+ def generate_protein_embedding(protein_input, input_type):
37
+ if input_type == "File":
38
+ mol = Chem.MolFromPDBFile(protein_input)
39
+ if not mol:
40
+ return None
41
+ fasta = Chem.MolToFASTA(mol).splitlines()[1]
42
+ else:
43
+ fasta = protein_input.strip()
44
+
45
  esm_model.eval()
46
  with torch.no_grad():
47
  inputs = esm_tokenizer(fasta, return_tensors="pt", padding=True, truncation=True, max_length=1024).to(device)
48
  outputs = esm_model(**inputs)
49
+ embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
50
  return embedding
51
 
52
+ # def generate_ligand_embedding(ligand):
53
+ # # Generate SMILES string
54
+ # mol = Chem.MolFromMol2File(ligand)
55
+ # if not mol:
56
+ # print("Could not convert file to ligand molecule")
57
+ # return None
58
+ # smiles = Chem.MolToSmiles(mol)
59
 
60
+ # # Generate ligand embedding
61
+ # chemberta_model.eval()
62
+ # with torch.no_grad():
63
+ # inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
64
+ # outputs = chemberta_model(**inputs)
65
+ # embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
66
+ # return embedding
67
+ def generate_ligand_embedding(ligand_input, input_type):
68
+ if input_type == "File":
69
+ mol = Chem.MolFromMol2File(ligand_input)
70
+ if not mol:
71
+ return None
72
+ smiles = Chem.MolToSmiles(mol)
73
+ else:
74
+ smiles = ligand_input.strip()
75
+
76
  chemberta_model.eval()
77
  with torch.no_grad():
78
  inputs = chemberta_tokenizer(smiles, return_tensors="pt", padding=True, truncation=True).to(device)
79
  outputs = chemberta_model(**inputs)
80
+ embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
81
  return embedding
82
 
83
  def value_conversion(logKa):
 
93
  else:
94
  return f"{Kd * 1e12:.4f} pM" # Picomolar
95
 
96
+ # def predict_affinity(protein_file, ligand_file):
97
+ # protein = generate_protein_embedding(protein_file)
98
+ # ligand = generate_ligand_embedding(ligand_file)
99
+ def predict_affinity(protein_input, protein_type, ligand_input, ligand_type):
100
+ protein = generate_protein_embedding(protein_input, protein_type)
101
+ ligand = generate_ligand_embedding(ligand_input, ligand_type)
102
+
103
  if protein is None:
104
  return "Unable to parse protein .pdb file"
105
  if ligand is None:
 
115
  affinity_value = value_conversion(log_prediction)
116
  return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
117
 
118
+ # # Gradio interface
119
+ # iface = gr.Interface(
120
+ # fn=predict_affinity,
121
+ # inputs=[gr.File(label="Protein .pdb file"), gr.File(label="Ligand .mol2 file")],
122
+ # outputs="text",
123
+ # title="Predict Protein-Ligand Binding Affinity",
124
+ # description="Upload the protein and ligand files to predict the binding affinity of the protein-ligand complex.",
125
+ # )
126
+
127
+ # # Run Gradio App
128
+ # if __name__ == "__main__":
129
+ # iface.launch()
130
+
131
+ def update_inputs(protein_type, ligand_type):
132
+ return (
133
+ gr.update(visible=(protein_type == "File")), # Protein file input
134
+ gr.update(visible=(protein_type == "FASTA")), # Protein FASTA input
135
+ gr.update(visible=(ligand_type == "File")), # Ligand file input
136
+ gr.update(visible=(ligand_type == "SMILES")) # Ligand SMILES input
137
+ )
138
+
139
+ with gr.Blocks() as iface:
140
+ gr.Markdown("# Predict Protein-Ligand Binding Affinity")
141
+ gr.Markdown("Upload files or enter FASTA/SMILES strings to predict binding affinity.")
142
+
143
+ protein_type = gr.Radio(["File", "FASTA"], label="Protein Input Type", value="File")
144
+ protein_file = gr.File(label="Protein .pdb file", visible=True)
145
+ protein_fasta = gr.Textbox(label="Protein FASTA sequence", visible=False)
146
+
147
+ ligand_type = gr.Radio(["File", "SMILES"], label="Ligand Input Type", value="File")
148
+ ligand_file = gr.File(label="Ligand .mol2 file", visible=True)
149
+ ligand_smiles = gr.Textbox(label="Ligand SMILES string", visible=False)
150
+
151
+ output = gr.Textbox(label="Prediction Result", lines=3)
152
+
153
+ submit_btn = gr.Button("Predict")
154
+ submit_btn.click(
155
+ predict_affinity,
156
+ inputs=[protein_file, protein_type, ligand_file, ligand_type],
157
+ outputs=output
158
+ )
159
+
160
+ protein_type.change(update_inputs, inputs=[protein_type, ligand_type], outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles])
161
+ ligand_type.change(update_inputs, inputs=[protein_type, ligand_type], outputs=[protein_file, protein_fasta, ligand_file, ligand_smiles])
162
+
163
+ iface.launch()