nninva commited on
Commit
ea594ff
·
verified ·
1 Parent(s): fbf1ca3

first commit

Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import joblib
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, EsmTokenizer, EsmModel
from rdkit import Chem

# Run on GPU when available; all models and tokenized inputs are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained models.
# ESM-2 (35M-parameter checkpoint) produces the protein embeddings.
esm = "facebook/esm2_t12_35M_UR50D" # generate protein embeddings
esm_tokenizer = EsmTokenizer.from_pretrained(esm)
esm_model = EsmModel.from_pretrained(esm).to(device)
# ChemBERTa produces the ligand (SMILES) embeddings.
chemberta = "DeepChem/ChemBERTa-10M-MTR" # generate ligand embeddings
chemberta_tokenizer = AutoTokenizer.from_pretrained(chemberta)
chemberta_model = AutoModel.from_pretrained(chemberta).to(device)
# Pre-fitted preprocessing + regressor loaded from local pickle files:
# scaler -> PCA -> SVR, applied in that order in predict_affinity.
scaler = joblib.load("scaler.pkl")
pca = joblib.load("pca.pkl")
svr = joblib.load("svr_model.pkl")
20
+
21
def generate_protein_embedding(protein):
    """Embed a protein .pdb file as a mean-pooled ESM-2 vector.

    Parameters
    ----------
    protein : str
        Path to a PDB file readable by RDKit.

    Returns
    -------
    numpy.ndarray or None
        A (1, hidden_size) array (mean over sequence positions of the
        last hidden layer), or None when the file cannot be parsed.
    """
    # Parse the structure and recover its sequence as FASTA text.
    structure = Chem.MolFromPDBFile(protein)
    if not structure:
        print("Could not convert file to protein molecule")
        return None
    # MolToFASTA emits a ">" header on the first line; the sequence
    # itself is on the second line.
    sequence = Chem.MolToFASTA(structure).splitlines()[1]

    # Tokenize (sequences longer than 1024 tokens are truncated) and
    # run a single forward pass without gradient tracking.
    esm_model.eval()
    tokens = esm_tokenizer(
        sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        hidden = esm_model(**tokens).last_hidden_state
    # Mean-pool over the sequence dimension and hand back a NumPy array.
    return hidden.mean(dim=1).cpu().numpy()
36
+
37
def generate_ligand_embedding(ligand):
    """Embed a ligand .mol2 file as a mean-pooled ChemBERTa vector.

    Parameters
    ----------
    ligand : str
        Path to a Mol2 file readable by RDKit.

    Returns
    -------
    numpy.ndarray or None
        A (1, hidden_size) array (mean over token positions of the
        last hidden layer), or None when the file cannot be parsed.
    """
    # Parse the molecule and serialize it to canonical SMILES.
    molecule = Chem.MolFromMol2File(ligand)
    if not molecule:
        print("Could not convert file to ligand molecule")
        return None
    smiles = Chem.MolToSmiles(molecule)

    # Tokenize the SMILES string and run one gradient-free forward pass.
    chemberta_model.eval()
    tokens = chemberta_tokenizer(
        smiles, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        hidden = chemberta_model(**tokens).last_hidden_state
    # Mean-pool over the token dimension and hand back a NumPy array.
    return hidden.mean(dim=1).cpu().numpy()
52
+
53
def value_conversion(logKa):
    """Convert a predicted logKa into a human-readable Kd string.

    Kd = 10 ** (-logKa); the result is rendered with four decimal
    places in the largest unit (mM, µM, nM, pM) that keeps the
    number >= 1 where possible.

    Parameters
    ----------
    logKa : float
        Predicted log10 association constant.

    Returns
    -------
    str
        Formatted dissociation constant, e.g. "31.6228 nM".
    """
    # Association -> dissociation constant: logKd = -logKa.
    Kd = 10 ** (-logKa)

    # Guard-clause ladder from the coarsest unit down to picomolar.
    if Kd >= 1e-3:
        return f"{Kd * 1e3:.4f} mM"  # Millimolar
    if Kd >= 1e-6:
        return f"{Kd * 1e6:.4f} µM"  # Micromolar
    if Kd >= 1e-9:
        return f"{Kd * 1e9:.4f} nM"  # Nanomolar
    return f"{Kd * 1e12:.4f} pM"  # Picomolar
65
+
66
def predict_affinity(protein_file, ligand_file):
    """Predict the binding affinity of a protein-ligand complex.

    Pipeline: embed both inputs, concatenate the embeddings, apply the
    pre-fitted scaler and PCA, then predict logKa with the SVR model.

    Parameters
    ----------
    protein_file : str
        Path to the protein .pdb file.
    ligand_file : str
        Path to the ligand .mol2 file.

    Returns
    -------
    str
        A multi-line result with logKa and the converted Kd, or an
        error message when either input cannot be parsed.
    """
    # Short-circuit on a bad protein before spending time embedding the ligand.
    protein = generate_protein_embedding(protein_file)
    if protein is None:
        return "Unable to parse protein .pdb file"
    ligand = generate_ligand_embedding(ligand_file)
    if ligand is None:
        # BUG FIX: previously said ".pdb file" — the ligand input is a .mol2 file.
        return "Unable to parse ligand .mol2 file"

    # Both embeddings are (1, hidden); concatenate along the feature axis.
    embedding = np.concatenate((protein, ligand), axis=1)

    # Apply scaling and PCA in the same order they were fitted.
    svr_input = scaler.transform(embedding)
    svr_input = pca.transform(svr_input)

    # Predict the log binding affinity and render it with a Kd conversion.
    log_prediction = svr.predict(svr_input)[0]
    affinity_value = value_conversion(log_prediction)
    return f"Predicted Binding Affinity:\nlogKa = {log_prediction:.4f}\nKd = {affinity_value}"
83
+
84
# Gradio interface: two file uploads in, one text result out.
iface = gr.Interface(
    fn=predict_affinity,
    inputs=[gr.File(label="Protein .pdb file"), gr.File(label="Ligand .mol2 file")],
    outputs="text",
    title="Predict Protein-Ligand Binding Affinity",
    description="Upload the protein and ligand files to predict the binding affinity of the protein-ligand complex.",
)

# Run Gradio App only when executed as a script (not on import).
if __name__ == "__main__":
    iface.launch()