Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
ec780ac
1
Parent(s):
2a36eec
freeze other components when the prediction is processing
Browse files
app.py
CHANGED
|
@@ -25,65 +25,75 @@ def get_description(property_name):
|
|
| 25 |
return dataset_descriptions[property_name]
|
| 26 |
|
| 27 |
def predict_single_label(smiles, property_name):
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
return prediction, "Prediction is done"
|
| 56 |
|
| 57 |
def predict_file(file, property_name):
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), prediction_file, "Prediction is done"
|
| 87 |
|
| 88 |
def validate_file(file):
|
| 89 |
try:
|
|
@@ -166,18 +176,57 @@ def build_inference():
|
|
| 166 |
file_types=[".smi", ".csv"], height=300)
|
| 167 |
predict_file_button = gr.Button("Predict", size='sm', visible=False)
|
| 168 |
download_button = gr.DownloadButton("Download", size='sm', visible=False)
|
|
|
|
| 169 |
|
| 170 |
# dropdown change event
|
| 171 |
dropdown.change(get_description, inputs=dropdown, outputs=description_box)
|
| 172 |
# predict single button click event
|
| 173 |
-
predict_single_smiles_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
# input file upload event
|
| 175 |
file_status = gr.State()
|
| 176 |
input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
|
| 177 |
# input file clear event
|
| 178 |
input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
|
| 179 |
# predict file button click event
|
| 180 |
-
predict_file_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
return demo
|
| 183 |
|
|
|
|
| 25 |
return dataset_descriptions[property_name]
|
| 26 |
|
| 27 |
def predict_single_label(smiles, property_name):
|
| 28 |
+
try:
|
| 29 |
+
adapter_id = candidate_models[property_name]
|
| 30 |
+
info = model.swith_adapter(property_name, adapter_id)
|
| 31 |
+
|
| 32 |
+
running_status = None
|
| 33 |
+
if info == "keep":
|
| 34 |
+
running_status = "Adapter is the same as the current one"
|
| 35 |
+
#print("Adapter is the same as the current one")
|
| 36 |
+
elif info == "switched":
|
| 37 |
+
running_status = "Adapter is switched successfully"
|
| 38 |
+
#print("Adapter is switched successfully")
|
| 39 |
+
elif info == "error":
|
| 40 |
+
running_status = "Adapter is not found"
|
| 41 |
+
#print("Adapter is not found")
|
| 42 |
+
return "NA", running_status
|
| 43 |
+
else:
|
| 44 |
+
running_status = "Unknown error"
|
| 45 |
+
return "NA", running_status
|
| 46 |
|
| 47 |
+
#prediction = model.predict(smiles, property_name, adapter_id)
|
| 48 |
+
prediction = model.predict_single_smiles(smiles, task_types[property_name])
|
| 49 |
+
if prediction is None:
|
| 50 |
+
return "NA", "Invalid SMILES string"
|
| 51 |
|
| 52 |
+
# if the prediction is a float, round it to 3 decimal places
|
| 53 |
+
if isinstance(prediction, float):
|
| 54 |
+
prediction = round(prediction, 3)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
# no matter what the error is, we should return
|
| 57 |
+
print(e)
|
| 58 |
+
return "NA", "Prediction failed"
|
| 59 |
|
| 60 |
return prediction, "Prediction is done"
|
| 61 |
|
| 62 |
def predict_file(file, property_name):
|
| 63 |
+
try:
|
| 64 |
+
adapter_id = candidate_models[property_name]
|
| 65 |
+
info = model.swith_adapter(property_name, adapter_id)
|
| 66 |
+
|
| 67 |
+
running_status = None
|
| 68 |
+
if info == "keep":
|
| 69 |
+
running_status = "Adapter is the same as the current one"
|
| 70 |
+
#print("Adapter is the same as the current one")
|
| 71 |
+
elif info == "switched":
|
| 72 |
+
running_status = "Adapter is switched successfully"
|
| 73 |
+
#print("Adapter is switched successfully")
|
| 74 |
+
elif info == "error":
|
| 75 |
+
running_status = "Adapter is not found"
|
| 76 |
+
#print("Adapter is not found")
|
| 77 |
+
return None, None, file, running_status
|
| 78 |
+
else:
|
| 79 |
+
running_status = "Unknown error"
|
| 80 |
+
return None, None, file, running_status
|
| 81 |
|
| 82 |
+
df = pd.read_csv(file)
|
| 83 |
+
# we have already checked the file contains the "smiles" column
|
| 84 |
+
df = model.predict_file(df, task_types[property_name])
|
| 85 |
+
# we should save this file to the disk to be downloaded
|
| 86 |
+
# rename the file to have "_prediction" suffix
|
| 87 |
+
prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
|
| 88 |
+
print(file, prediction_file)
|
| 89 |
+
# save the file to the disk
|
| 90 |
+
df.to_csv(prediction_file, index=False)
|
| 91 |
+
except Exception as e:
|
| 92 |
+
# no matter what the error is, we should return
|
| 93 |
+
print(e)
|
| 94 |
+
return None, None, gr.update(visible=False), file, "Prediction failed"
|
| 95 |
|
| 96 |
+
return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
|
| 97 |
|
| 98 |
def validate_file(file):
|
| 99 |
try:
|
|
|
|
| 176 |
file_types=[".smi", ".csv"], height=300)
|
| 177 |
predict_file_button = gr.Button("Predict", size='sm', visible=False)
|
| 178 |
download_button = gr.DownloadButton("Download", size='sm', visible=False)
|
| 179 |
+
stop_button = gr.Button("Stop", size='sm', visible=False)
|
| 180 |
|
| 181 |
# dropdown change event
|
| 182 |
dropdown.change(get_description, inputs=dropdown, outputs=description_box)
|
| 183 |
# predict single button click event
|
| 184 |
+
predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
|
| 185 |
+
gr.update(interactive=False),
|
| 186 |
+
gr.update(interactive=False),
|
| 187 |
+
gr.update(interactive=False),
|
| 188 |
+
gr.update(interactive=False),
|
| 189 |
+
gr.update(interactive=False),
|
| 190 |
+
gr.update(interactive=False),
|
| 191 |
+
gr.update(interactive=False),
|
| 192 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
| 193 |
+
.then(predict_single_label, inputs=[textbox, dropdown], outputs=[prediction, running_terminal_label])\
|
| 194 |
+
.then(lambda:(gr.update(interactive=True),
|
| 195 |
+
gr.update(interactive=True),
|
| 196 |
+
gr.update(interactive=True),
|
| 197 |
+
gr.update(interactive=True),
|
| 198 |
+
gr.update(interactive=True),
|
| 199 |
+
gr.update(interactive=True),
|
| 200 |
+
gr.update(interactive=True),
|
| 201 |
+
gr.update(interactive=True),
|
| 202 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
|
| 203 |
# input file upload event
|
| 204 |
file_status = gr.State()
|
| 205 |
input_file.upload(fn=validate_file, inputs=input_file, outputs=[file_status, input_file, predict_file_button, download_button]).success(raise_error, inputs=file_status, outputs=file_status)
|
| 206 |
# input file clear event
|
| 207 |
input_file.clear(fn=clear_file, inputs=[download_button], outputs=[predict_file_button, download_button, input_file])
|
| 208 |
# predict file button click event
|
| 209 |
+
predict_file_event = predict_file_button.click(lambda:(gr.update(interactive=False),
|
| 210 |
+
gr.update(interactive=False),
|
| 211 |
+
gr.update(interactive=False),
|
| 212 |
+
gr.update(interactive=False, visible=True),
|
| 213 |
+
gr.update(interactive=False),
|
| 214 |
+
gr.update(interactive=True, visible=False),
|
| 215 |
+
gr.update(interactive=False),
|
| 216 |
+
gr.update(interactive=False),
|
| 217 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
| 218 |
+
.then(predict_file, inputs=[input_file, dropdown], outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
|
| 219 |
+
.then(lambda:(gr.update(interactive=True),
|
| 220 |
+
gr.update(interactive=True),
|
| 221 |
+
gr.update(interactive=True),
|
| 222 |
+
gr.update(interactive=True),
|
| 223 |
+
gr.update(interactive=True),
|
| 224 |
+
gr.update(interactive=True),
|
| 225 |
+
gr.update(interactive=True),
|
| 226 |
+
gr.update(interactive=True),
|
| 227 |
+
) , outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button, download_button, stop_button, input_file, running_terminal_label])
|
| 228 |
+
# stop button click event
|
| 229 |
+
#stop_button.click(fn=None, inputs=None, outputs=None, cancels=[predict_file_event])
|
| 230 |
|
| 231 |
return demo
|
| 232 |
|
utils.py
CHANGED
|
@@ -201,7 +201,7 @@ class MolecularPropertyPredictionModel():
|
|
| 201 |
self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
| 202 |
|
| 203 |
#self.base_model.to("cuda")
|
| 204 |
-
print(self.base_model)
|
| 205 |
|
| 206 |
def swith_adapter(self, adapter_name, adapter_id):
|
| 207 |
# return flag:
|
|
@@ -220,6 +220,7 @@ class MolecularPropertyPredictionModel():
|
|
| 220 |
#print(self.lora_model)
|
| 221 |
|
| 222 |
self.base_model.set_adapter(adapter_name)
|
|
|
|
| 223 |
|
| 224 |
#if adapter_name not in self.apapter_scaler_path:
|
| 225 |
# self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
|
@@ -244,8 +245,8 @@ class MolecularPropertyPredictionModel():
|
|
| 244 |
batch_size=16,
|
| 245 |
collate_fn=self.data_collator,
|
| 246 |
)
|
| 247 |
-
# predict
|
| 248 |
|
|
|
|
| 249 |
y_pred = []
|
| 250 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
| 251 |
with torch.no_grad():
|
|
|
|
| 201 |
self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
| 202 |
|
| 203 |
#self.base_model.to("cuda")
|
| 204 |
+
#print(self.base_model)
|
| 205 |
|
| 206 |
def swith_adapter(self, adapter_name, adapter_id):
|
| 207 |
# return flag:
|
|
|
|
| 220 |
#print(self.lora_model)
|
| 221 |
|
| 222 |
self.base_model.set_adapter(adapter_name)
|
| 223 |
+
self.base_model.eval()
|
| 224 |
|
| 225 |
#if adapter_name not in self.apapter_scaler_path:
|
| 226 |
# self.apapter_scaler_path[adapter_name] = hf_hub_download(adapter_id, filename="scaler.pkl", token = os.environ.get("TOKEN"))
|
|
|
|
| 245 |
batch_size=16,
|
| 246 |
collate_fn=self.data_collator,
|
| 247 |
)
|
|
|
|
| 248 |
|
| 249 |
+
# predict
|
| 250 |
y_pred = []
|
| 251 |
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc="Evaluating"):
|
| 252 |
with torch.no_grad():
|