Spaces:
Running on Zero
Running on Zero
Commit ·
6bb1bdf
1
Parent(s): 50fe1a2
finish the basic function
Browse files- app.py +185 -4
- dataset_descriptions.json +112 -0
- utils.py +286 -0
app.py
CHANGED
|
@@ -1,7 +1,188 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
def
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from huggingface_hub import HfApi, get_collection, list_collections
|
| 3 |
+
from utils import MolecularPropertyPredictionModel, task_types, dataset_descriptions
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import os
|
| 6 |
|
| 7 |
+
def get_models():
    """Discover the molecular-property-prediction adapters in the ChemFM collection.

    Returns:
        dict mapping the short model name (last path component of the repo id)
        to the full Hugging Face repo id.

    Raises:
        ValueError: if a model in the collection has no registered task type or
            dataset description, so metadata mismatches fail at startup rather
            than at prediction time.
    """
    # This is the collection id for the molecular property prediction models.
    collection = get_collection("ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c")
    models = dict()
    for item in collection.items:
        if item.item_type != "model":
            continue
        item_name = item.item_id.split("/")[-1]
        models[item_name] = item.item_id
        # Explicit raises instead of `assert`: asserts are stripped under `python -O`.
        if item_name not in task_types:
            raise ValueError(f"{item_name} is not in the task_types")
        if item_name not in dataset_descriptions:
            raise ValueError(f"{item_name} is not in the dataset_descriptions")

    return models
| 19 |
|
| 20 |
+
# Resolve the published adapters once at startup; `properties` drives the UI dropdown.
candidate_models = get_models()
properties = list(candidate_models.keys())
# Single shared model instance; LoRA adapters are hot-swapped per request.
model = MolecularPropertyPredictionModel()
|
| 23 |
+
|
| 24 |
+
def get_description(property_name):
    """Return the human-readable description for the selected property."""
    description = dataset_descriptions[property_name]
    return description
|
| 26 |
+
|
| 27 |
+
def predict_single_label(smiles, property_name):
    """Predict one property value for a single SMILES string.

    Args:
        smiles: molecule SMILES string entered by the user.
        property_name: key into `candidate_models` / `task_types`.

    Returns:
        (prediction, status_message); prediction is "NA" on any failure.
    """
    adapter_id = candidate_models[property_name]
    info = model.swith_adapter(property_name, adapter_id)

    # Map each adapter-switch outcome to a user-facing status message; only
    # "keep" and "switched" are allowed to proceed to prediction.
    status_messages = {
        "keep": "Adapter is the same as the current one",
        "switched": "Adapter is switched successfully",
        "error": "Adapter is not found",
    }
    running_status = status_messages.get(info, "Unknown error")
    if info not in ("keep", "switched"):
        return "NA", running_status

    prediction = model.predict_single_smiles(smiles, task_types[property_name])
    if prediction is None:
        return "NA", "Invalid SMILES string"

    # Round regression outputs for display; classification labels pass through.
    if isinstance(prediction, float):
        prediction = round(prediction, 3)

    return prediction, "Prediction is done"
| 56 |
+
|
| 57 |
+
def predict_file(file, property_name):
    """Run batch prediction over an uploaded CSV of SMILES.

    Args:
        file: path of the uploaded file (already checked by `validate_file`
            to be a csv containing a "smiles" column).
        property_name: key into `candidate_models` / `task_types`.

    Returns:
        (predict_button_update, download_button_update, file_path, status),
        matching the Gradio outputs wired in `build_inference`.
    """
    adapter_id = candidate_models[property_name]
    info = model.swith_adapter(property_name, adapter_id)

    # Map each adapter-switch outcome to a user-facing status message; only
    # "keep" and "switched" are allowed to proceed to prediction.
    status_messages = {
        "keep": "Adapter is the same as the current one",
        "switched": "Adapter is switched successfully",
        "error": "Adapter is not found",
    }
    running_status = status_messages.get(info, "Unknown error")
    if info not in ("keep", "switched"):
        return None, None, file, running_status

    # The file was validated on upload, so it contains a "smiles" column.
    df = pd.read_csv(file)
    df = model.predict_file(df, task_types[property_name])

    # Save the predictions to disk with a "_prediction" suffix so the
    # download button can serve the file and `clear_file` can find it later.
    prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
    df.to_csv(prediction_file, index=False)

    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"
| 87 |
+
|
| 88 |
+
def validate_file(file):
    """Validate an uploaded molecule file before enabling batch prediction.

    Only ".csv" files that contain a "smiles" column and at most 100 rows are
    accepted (".smi" uploads are currently rejected).

    Returns:
        (status_message, file_or_None, predict_button_update, download_button_update)
    """
    def _reject(message):
        # Clear the file input and keep the predict/download buttons hidden.
        return message, None, gr.update(visible=False), gr.update(visible=False)

    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                return _reject("Invalid file content. The csv file must contain column named 'smiles'")
            # Number of molecules, checked against the space's limit below.
            length = len(df["smiles"])
        else:
            # ".smi" (and any other extension) is not supported.
            return _reject("Invalid file extension")
    except Exception:
        return _reject("Invalid file content.")

    if length > 100:
        return _reject("The space does not support the file containing more than 100 SMILES")

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def raise_error(status):
    """Surface any non-"Valid file" validation status as a Gradio error dialog."""
    if status == "Valid file":
        return None
    raise gr.Error(status)
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def clear_file(download_button):
    """Clean up when the file input is cleared.

    Deletes the last prediction file (whose path is held by the download
    button) and, if present, the original upload it was derived from.

    Args:
        download_button: current download-button value, i.e. the prediction
            file path (or None if no prediction was produced).

    Returns:
        Updates hiding the predict/download buttons and clearing the file input.
    """
    prediction_path = download_button
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        # The upload shares the prediction file's stem; try both source extensions.
        for original_data_file in (prediction_path.replace("_prediction.csv", ".csv"),
                                   prediction_path.replace("_prediction.csv", ".smi")):
            if os.path.exists(original_data_file):
                os.remove(original_data_file)

    return gr.update(visible=False), gr.update(visible=False), None
| 144 |
+
|
| 145 |
+
def build_inference():
    """Assemble the Gradio Blocks UI and wire up all event handlers."""
    with gr.Blocks() as demo:
        # Property selector with a read-only description of the selected task.
        dropdown = gr.Dropdown(properties, label="Property", value=properties[0])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[properties[0]])
        # Single-SMILES input, prediction label, and a status console.
        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text", placeholder="Provide a SMILES string here",
                                     lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
                prediction = gr.Label("Prediction will appear here")

            running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

        # Batch prediction widgets; predict/download stay hidden until a valid
        # file is uploaded / a prediction file exists.
        input_file = gr.File(label="Molecule file",
                             file_count='single',
                             file_types=[".smi", ".csv"], height=300)
        predict_file_button = gr.Button("Predict", size='sm', visible=False)
        download_button = gr.DownloadButton("Download", size='sm', visible=False)

        # dropdown change event: refresh the description text
        dropdown.change(get_description, inputs=dropdown, outputs=description_box)
        # predict single button click event
        predict_single_smiles_button.click(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])
        # input file upload event: validate, then pop an error dialog on failure
        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
        # input file clear event: delete temp files and hide the buttons
        input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
        # predict file button click event
        predict_file_button.click(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, input_file, running_terminal_label])

    return demo
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Build the UI at import time (Hugging Face Spaces imports this module and
# serves `demo`); launching directly is only needed for local runs.
demo = build_inference()

if __name__ == '__main__':
    demo.launch()
|
dataset_descriptions.json
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"ADMET_Caco2_Wang": {
|
| 3 |
+
"task_type": "regression",
|
| 4 |
+
"description": "predict drug permeability, measured in cm/s, using the Caco-2 cell line as an in vitro model to simulate human intestinal tissue permeability",
|
| 5 |
+
"num_molecules": 906
|
| 6 |
+
},
|
| 7 |
+
"ADMET_Bioavailability_Ma": {
|
| 8 |
+
"task_type": "classification",
|
| 9 |
+
"description": "predict oral bioavailability with binary labels, indicating the rate and extent a drug becomes available at its site of action",
|
| 10 |
+
"num_molecules": 640
|
| 11 |
+
},
|
| 12 |
+
"ADMET_Lipophilicity_AstraZeneca": {
|
| 13 |
+
"task_type": "regression",
|
| 14 |
+
"description": "predict lipophilicity with continuous labels, measured as a log-ratio, indicating a drug's ability to dissolve in lipid environments",
|
| 15 |
+
"num_molecules": 4200
|
| 16 |
+
},
|
| 17 |
+
"ADMET_Solubility_AqSolDB": {
|
| 18 |
+
"task_type": "regression",
|
| 19 |
+
"description": "predict aqueous solubility with continuous labels, measured in log mol/L, indicating a drug's ability to dissolve in water",
|
| 20 |
+
"num_molecules": 9982
|
| 21 |
+
},
|
| 22 |
+
"ADMET_HIA_Hou": {
|
| 23 |
+
"task_type": "classification",
|
| 24 |
+
"description": "predict human intestinal absorption (HIA) with binary labels, indicating a drug's ability to be absorbed into the bloodstream",
|
| 25 |
+
"num_molecules": 578
|
| 26 |
+
},
|
| 27 |
+
"ADMET_Pgp_Broccatelli": {
|
| 28 |
+
"task_type": "classification",
|
| 29 |
+
"description": "predict P-glycoprotein (Pgp) inhibition with binary labels, indicating a drug's potential to alter bioavailability and overcome multidrug resistance",
|
| 30 |
+
"num_molecules": 1212
|
| 31 |
+
},
|
| 32 |
+
"ADMET_BBB_Martins": {
|
| 33 |
+
"task_type": "classification",
|
| 34 |
+
"description": "predict blood-brain barrier permeability with binary labels, indicating a drug's ability to penetrate the barrier to reach the brain",
|
| 35 |
+
"num_molecules": 1915
|
| 36 |
+
},
|
| 37 |
+
"ADMET_PPBR_AZ": {
|
| 38 |
+
"task_type": "regression",
|
| 39 |
+
"description": "predict plasma protein binding rate with continuous labels, indicating the percentage of a drug bound to plasma proteins in the blood",
|
| 40 |
+
"num_molecules": 1797
|
| 41 |
+
},
|
| 42 |
+
"ADMET_VDss_Lombardo": {
|
| 43 |
+
"task_type": "regression",
|
| 44 |
+
"description": "predict the volume of distribution at steady state (VDss), indicating drug concentration in tissues versus blood",
|
| 45 |
+
"num_molecules": 1130
|
| 46 |
+
},
|
| 47 |
+
"ADMET_CYP2C9_Veith": {
|
| 48 |
+
"task_type": "classification",
|
| 49 |
+
"description": "predict CYP2C9 inhibition with binary labels, indicating the drug's ability to inhibit the CYP2C9 enzyme involved in metabolism",
|
| 50 |
+
"num_molecules": 12092
|
| 51 |
+
},
|
| 52 |
+
"ADMET_CYP2D6_Veith": {
|
| 53 |
+
"task_type": "classification",
|
| 54 |
+
"description": "predict CYP2D6 inhibition with binary labels, indicating the drug's potential to inhibit the CYP2D6 enzyme involved in metabolism",
|
| 55 |
+
"num_molecules": 13130
|
| 56 |
+
},
|
| 57 |
+
"ADMET_CYP3A4_Veith": {
|
| 58 |
+
"task_type": "classification",
|
| 59 |
+
"description": "predict CPY3A4 inhibition with binary labels, indicating the drug's ability to inhibit the CPY3A4 enzyme involved in metabolism",
|
| 60 |
+
"num_molecules": 12328
|
| 61 |
+
},
|
| 62 |
+
"ADMET_CYP2C9_Substrate_CarbonMangels": {
|
| 63 |
+
"task_type": "classification",
|
| 64 |
+
"description": "predict whether a drug is a substrate of the CYP2C9 enzyme with binary labels, indicating its potential to be metabolized",
|
| 65 |
+
"num_molecules": 666
|
| 66 |
+
},
|
| 67 |
+
"ADMET_CYP2D6_Substrate_CarbonMangels": {
|
| 68 |
+
"task_type": "classification",
|
| 69 |
+
"description": "predict whether a drug is a substrate of the CYP2D6 enzyme with binary labels, indicating its potential to be metabolized",
|
| 70 |
+
"num_molecules": 664
|
| 71 |
+
},
|
| 72 |
+
"ADMET_CYP3A4_Substrate_CarbonMangels": {
|
| 73 |
+
"task_type": "classification",
|
| 74 |
+
"description": "predict whether a drug is a substrate of the CYP3A4 enzyme with binary labels, indicating its potential to be metabolized",
|
| 75 |
+
"num_molecules": 667
|
| 76 |
+
},
|
| 77 |
+
"ADMET_Half_Life_Obach": {
|
| 78 |
+
"task_type": "regression",
|
| 79 |
+
"description": "predict the half-life duration of a drug, measured in hours, indicating the time for its concentration to reduce by half",
|
| 80 |
+
"num_molecules": 667
|
| 81 |
+
},
|
| 82 |
+
"ADMET_Clearance_Hepatocyte_AZ": {
|
| 83 |
+
"task_type": "regression",
|
| 84 |
+
"description": "predict drug clearance, measured in \u03bcL/min/10^6 cells, from hepatocyte experiments, indicating the rate at which the drug is removed from body",
|
| 85 |
+
"num_molecules": 1020
|
| 86 |
+
},
|
| 87 |
+
"ADMET_Clearance_Microsome_AZ": {
|
| 88 |
+
"task_type": "regression",
|
| 89 |
+
"description": "predict drug clearance, measured in mL/min/g, from microsome experiments, indicating the rate at which the drug is removed from body",
|
| 90 |
+
"num_molecules": 1102
|
| 91 |
+
},
|
| 92 |
+
"ADMET_LD50_Zhu": {
|
| 93 |
+
"task_type": "regression",
|
| 94 |
+
"description": "predict the acute toxicity of a drug, measured as the dose leading to lethal effects in log(kg/mol)",
|
| 95 |
+
"num_molecules": 7385
|
| 96 |
+
},
|
| 97 |
+
"ADMET_hERG": {
|
| 98 |
+
"task_type": "classification",
|
| 99 |
+
"description": "predict whether a drug blocks the hERG channel, which is crucial for heart rhythm, potentially leading to adverse effects",
|
| 100 |
+
"num_molecules": 648
|
| 101 |
+
},
|
| 102 |
+
"ADMET_AMES": {
|
| 103 |
+
"task_type": "classification",
|
| 104 |
+
"description": "predict whether a drug is mutagenic with binary labels, indicating its ability to induce genetic alterations",
|
| 105 |
+
"num_molecules": 7255
|
| 106 |
+
},
|
| 107 |
+
"ADMET_DILI": {
|
| 108 |
+
"task_type": "classification",
|
| 109 |
+
"description": "predict whether a drug can cause liver injury with binary labels, indicating its potential for hepatotoxicity",
|
| 110 |
+
"num_molecules": 475
|
| 111 |
+
}
|
| 112 |
+
}
|
utils.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
|
| 2 |
+
from typing import Optional, Dict, Sequence, List
|
| 3 |
+
import transformers
|
| 4 |
+
from peft import PeftModel
|
| 5 |
+
import torch
|
| 6 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from datasets import Dataset
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import numpy as np
|
| 12 |
+
from huggingface_hub import hf_hub_download
|
| 13 |
+
import os
|
| 14 |
+
import pickle
|
| 15 |
+
from sklearn import preprocessing
|
| 16 |
+
import json
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from rdkit import RDLogger, Chem
|
| 20 |
+
# Suppress RDKit INFO messages
|
| 21 |
+
RDLogger.DisableLog('rdApp.*')
|
| 22 |
+
|
| 23 |
+
# we have a dictionary to store the task types of the models
|
| 24 |
+
# Task type ("regression" or "classification") per adapter repo name.
# NOTE(review): only two adapters are listed, but app.py's get_models()
# requires an entry for every model in the collection — confirm this dict
# covers all published adapters.
task_types = {
    "admet_ppbr_az": "regression",
    "admet_half_life_obach": "regression",
}
| 28 |
+
|
| 29 |
+
# Read the dataset metadata and build the lowercase-keyed description strings
# shown in the UI.
with open("dataset_descriptions.json", "r") as f:
    dataset_description_temp = json.load(f)

dataset_descriptions = {
    name.lower(): (
        f"{name.lower()} is a {info['task_type']} task, "
        f"where the goal is to {info['description']}."
    )
    for name, info in dataset_description_temp.items()
}
| 40 |
+
|
| 41 |
+
class Scaler:
    """Target scaler: shift to non-negative, optional log10, then standardize.

    Targets are shifted by the minimum seen during `fit` (if negative),
    optionally mapped through log10(y + 1), and standardized with sklearn's
    StandardScaler. `inverse_transform` undoes the whole pipeline.
    """

    def __init__(self, log=False):
        self.log = log        # whether to apply the log10(y + 1) transform
        self.offset = None    # shift making the training targets non-negative
        self.scaler = None    # fitted sklearn StandardScaler

    def fit(self, y):
        """Fit the offset and the standardizer on the training targets."""
        # Shift so the smallest value maps to zero (no-op when all are >= 0).
        self.offset = np.min([np.min(y), 0.0])
        shifted = y.reshape(-1, 1) - self.offset

        if self.log:
            shifted = np.log10(shifted + 1.0)

        self.scaler = preprocessing.StandardScaler().fit(shifted)

    def transform(self, y):
        """Apply offset, optional log, and standardization to `y`."""
        shifted = y.reshape(-1, 1) - self.offset

        if self.log:
            shifted = np.log10(shifted + 1.0)

        return self.scaler.transform(shifted)

    def inverse_transform(self, y_scale):
        """Map standardized predictions back to the original target space."""
        values = self.scaler.inverse_transform(y_scale.reshape(-1, 1))

        if self.log:
            values = 10.0**values - 1.0

        return values + self.offset
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    non_special_tokens = None,
):
    """Resize tokenizer and embedding.

    Adds the given special (and optional non-special) tokens to the tokenizer,
    resizes the model's input embeddings to match, and initializes any new
    embedding rows with the mean of the pre-existing ones.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # Register the new tokens. Guard against `non_special_tokens=None`, which
    # `add_tokens` does not accept. The return values are deliberately unused:
    # the authoritative new-token count is recomputed from len(tokenizer) below.
    tokenizer.add_special_tokens(special_tokens_dict)
    tokenizer.add_tokens(non_special_tokens or [])
    num_old_tokens = model.get_input_embeddings().weight.shape[0]
    num_new_tokens = len(tokenizer) - num_old_tokens
    if num_new_tokens == 0:
        return

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings_data = model.get_input_embeddings().weight.data

        # Initialize the new rows with the mean of the existing embeddings so
        # fine-tuning starts from a sensible point rather than random values.
        input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
    print(f"Resized tokenizer and embedding from {num_old_tokens} to {len(tokenizer)} tokens.")
| 105 |
+
|
| 106 |
+
@dataclass
class DataCollator(object):
    """Collate SMILES examples into padded token batches for the model.

    Each SMILES is canonicalized with RDKit and wrapped between
    `molecule_start_str` and `end_str` before tokenization.
    """
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int        # truncation length for the tokenized sources
    molecule_start_str: str    # marker prepended before each molecule
    end_str: str               # marker appended after each molecule

    def augment_molecule(self, molecule: str) -> str:
        # NOTE(review): `self.sme` is never defined on this class, so calling
        # this would raise AttributeError — it looks like leftover training
        # code (SMILES enumeration/augmentation); confirm before use.
        return self.sme.augment([molecule])[0]

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build input_ids / attention_mask tensors from a list of examples."""
        sources = []
        targets = []

        for example in instances:
            smiles = example['smiles'].strip()
            # Canonicalize. Assumes the SMILES was validated upstream:
            # MolFromSmiles returns None for invalid input and MolToSmiles
            # would then raise.
            smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))

            # get the properties except the smiles and mol_id cols
            #props = [example[col] if example[col] is not None else np.nan for col in sorted(example.keys()) if col not in ['smiles', 'is_aug']]
            source = f"{self.molecule_start_str}{smiles}{self.end_str}"
            sources.append(source)

        # Tokenize (no special tokens: the markers above already frame the input)
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        input_ids = [torch.tensor(tokenized_source) for tokenized_source in tokenized_sources_with_prompt['input_ids']]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        data_dict = {
            'input_ids': input_ids,
            # Mask out the padding positions added by pad_sequence.
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
        }

        return data_dict
| 146 |
+
|
| 147 |
+
class MolecularPropertyPredictionModel():
    """ChemFM-3B sequence-classification model with hot-swappable LoRA adapters.

    One base model is loaded on CPU at construction; per-property LoRA
    adapters (and, for regression tasks, their target scalers) are loaded on
    demand via `swith_adapter`.
    """

    def __init__(self):
        # Name of the currently loaded adapter (None until the first switch).
        self.adapter_name = None

        # Cache of downloaded scaler paths per adapter, so the same scaler is
        # not downloaded from the Hub more than once.
        self.adapter_scaler_path = dict()

        DEFAULT_PAD_TOKEN = "[PAD]"

        # load the base model
        config = AutoConfig.from_pretrained(
            "ChemFM/ChemFM-3B",
            num_labels=1,
            finetuning_task="classification",  # this is not about our task type
            trust_remote_code=True,
        )

        self.base_model = AutoModelForSequenceClassification.from_pretrained(
            "ChemFM/ChemFM-3B",
            config=config,
            device_map="cpu",
            trust_remote_code=True,
        )

        # load the tokenizer (shared by all adapters)
        self.tokenizer = AutoTokenizer.from_pretrained(
            "ChemFM/admet_ppbr_az",
            trust_remote_code=True,
        )
        special_tokens_dict = dict(pad_token=DEFAULT_PAD_TOKEN)
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=self.tokenizer,
            model=self.base_model
        )
        self.base_model.config.pad_token_id = self.tokenizer.pad_token_id

        self.data_collator = DataCollator(
            tokenizer=self.tokenizer,
            source_max_len=512,
            molecule_start_str="<molstart>",
            end_str="<eos>",
        )

    def swith_adapter(self, adapter_name, adapter_id):
        """Load the LoRA adapter (and scaler, if any) for `adapter_name`.

        Returns:
            "keep":     adapter is the same as the current one
            "switched": adapter was switched successfully
            "error":    adapter (or its scaler) could not be loaded
        """
        if adapter_name == self.adapter_name:
            return "keep"
        # switch adapter
        try:
            self.lora_model = PeftModel.from_pretrained(self.base_model, adapter_id)
            if adapter_name not in self.adapter_scaler_path:
                self.adapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl")
            if os.path.exists(self.adapter_scaler_path[adapter_name]):
                with open(self.adapter_scaler_path[adapter_name], "rb") as f:
                    self.scaler = pickle.load(f)
            else:
                self.scaler = None

            # Record the new name only after everything loaded successfully,
            # so a failed switch cannot masquerade as "keep" on a retry.
            self.adapter_name = adapter_name
            return "switched"
        except Exception:
            # Adapter repo missing, download failure, or bad scaler file.
            return "error"

    def predict(self, valid_df, task_type):
        """Run the current adapter over a dataframe of valid SMILES.

        Args:
            valid_df: dataframe whose 'smiles' entries are all valid.
            task_type: "regression" or "classification".

        Returns:
            numpy array of shape (n, 1): values (inverse-scaled when a scaler
            is available) for regression, 0/1 labels for classification.
        """
        test_dataset = Dataset.from_pandas(valid_df)
        # construct the dataloader
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=4,
            collate_fn=self.data_collator,
        )

        y_pred = []
        for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
            with torch.no_grad():
                batch = {k: v.to(self.lora_model.device) for k, v in batch.items()}
                outputs = self.lora_model(**batch)
                if task_type == "regression":
                    y_pred.append(outputs.logits.cpu().detach().numpy())
                else:
                    # Binary label thresholded from the single logit.
                    y_pred.append((torch.sigmoid(outputs.logits) > 0.5).cpu().detach().numpy())

        y_pred = np.concatenate(y_pred, axis=0)
        if task_type == "regression" and self.scaler is not None:
            # Map standardized model outputs back to the original target units.
            y_pred = self.scaler.inverse_transform(y_pred)

        return y_pred

    def predict_single_smiles(self, smiles, task_type):
        """Predict one SMILES; returns a scalar, or None if the SMILES is invalid."""
        assert task_type in ["regression", "classification"]

        # check the SMILES string is valid
        if not Chem.MolFromSmiles(smiles):
            return None

        valid_df = pd.DataFrame([smiles], columns=['smiles'])
        results = self.predict(valid_df, task_type)
        return results.item()

    def predict_file(self, df, task_type):
        """Predict every valid SMILES in `df`; invalid rows get NaN.

        Args:
            df: dataframe with a 'smiles' column (may contain invalid entries).
            task_type: "regression" or "classification".

        Returns:
            The input dataframe with an added 'prediction' column, original
            row order preserved.
        """
        # Materialize the row order as a temporary 'index' column.
        df = df.reset_index()
        # Partition rows into valid and invalid SMILES.
        valid_idx = []
        invalid_idx = []
        for idx, smiles in enumerate(df['smiles']):
            if Chem.MolFromSmiles(smiles):
                valid_idx.append(idx)
            else:
                invalid_idx.append(idx)
        valid_df = df.loc[valid_idx]
        # get the smiles list
        valid_df_smiles = valid_df['smiles'].tolist()

        input_df = pd.DataFrame(valid_df_smiles, columns=['smiles'])
        results = self.predict(input_df, task_type)

        # add the results to the dataframe; invalid rows get NaN
        df.loc[valid_idx, 'prediction'] = results
        df.loc[invalid_idx, 'prediction'] = np.nan

        # drop the temporary index column
        df = df.drop(columns=['index'])

        return df