Spaces:
Runtime error
Runtime error
Eachan Johnson
commited on
Commit
·
a11f59c
1
Parent(s):
7dbc2af
Update examples and add second plot
Browse files
app.py
CHANGED
|
@@ -23,6 +23,7 @@ from schemist.tables import converter
|
|
| 23 |
import torch
|
| 24 |
|
| 25 |
CACHE = "./cache"
|
|
|
|
| 26 |
HEADER_FILE = os.path.join("sources", "header.md")
|
| 27 |
MODEL_REPOS = {
|
| 28 |
"Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
|
|
@@ -78,6 +79,10 @@ def convert_one(
|
|
| 78 |
input_representation: str = 'smiles',
|
| 79 |
output_representation: Union[Iterable[str], str] = 'smiles'
|
| 80 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
df = pd.DataFrame({
|
| 83 |
input_representation: _clean_split_input(strings),
|
|
@@ -91,23 +96,17 @@ def convert_one(
|
|
| 91 |
)
|
| 92 |
|
| 93 |
|
| 94 |
-
def
|
| 95 |
-
|
| 96 |
-
input_representation: str = 'smiles',
|
| 97 |
predict: Union[Iterable[str], str] = 'smiles',
|
| 98 |
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
| 99 |
-
):
|
|
|
|
|
|
|
| 100 |
if extra_metrics is None:
|
| 101 |
extra_metrics = []
|
| 102 |
else:
|
| 103 |
extra_metrics = cast(extra_metrics, to=list)
|
| 104 |
-
prediction_df = convert_one(
|
| 105 |
-
strings=strings,
|
| 106 |
-
input_representation=input_representation,
|
| 107 |
-
output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
|
| 108 |
-
)
|
| 109 |
-
species_to_predict = cast(predict, to=list)
|
| 110 |
-
prediction_cols = []
|
| 111 |
for species in species_to_predict:
|
| 112 |
message = f"Predicting for species: {species}"
|
| 113 |
print_err(message)
|
|
@@ -116,7 +115,7 @@ def predict_one(
|
|
| 116 |
this_features = this_modelbox._input_cols
|
| 117 |
this_labels = this_modelbox._label_cols
|
| 118 |
this_prediction_input = (
|
| 119 |
-
|
| 120 |
.rename(columns={
|
| 121 |
"smiles": this_features[0],
|
| 122 |
})
|
|
@@ -132,10 +131,10 @@ def predict_one(
|
|
| 132 |
).with_format("numpy")["__prediction__"].flatten()
|
| 133 |
print(prediction)
|
| 134 |
this_col = f"{species}: predicted MIC (µM)"
|
| 135 |
-
|
| 136 |
prediction_cols.append(this_col)
|
| 137 |
this_col = f"{species}: predicted MIC (µg / mL)"
|
| 138 |
-
|
| 139 |
prediction_cols.append(this_col)
|
| 140 |
|
| 141 |
for extra_metric in extra_metrics:
|
|
@@ -155,10 +154,33 @@ def predict_one(
|
|
| 155 |
)
|
| 156 |
.with_format("numpy")
|
| 157 |
)
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
return gr.DataFrame(
|
| 161 |
-
prediction_df[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
visible=True
|
| 163 |
)
|
| 164 |
|
|
@@ -209,70 +231,38 @@ def predict_file(
|
|
| 209 |
else:
|
| 210 |
extra_metrics = cast(extra_metrics, to=list)
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
prediction_df = convert_file(
|
| 213 |
df,
|
| 214 |
column=column,
|
| 215 |
input_representation=input_representation,
|
| 216 |
output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
|
| 217 |
)
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
)
|
| 239 |
-
print(this_prediction_input)
|
| 240 |
-
prediction = this_modelbox.predict(
|
| 241 |
-
data=this_prediction_input,
|
| 242 |
-
features=this_features,
|
| 243 |
-
labels=this_labels,
|
| 244 |
-
aggregator="mean",
|
| 245 |
-
cache=CACHE,
|
| 246 |
-
).with_format("numpy")["__prediction__"].flatten()
|
| 247 |
-
print(prediction)
|
| 248 |
-
this_col = f"{species}: predicted MIC (µM)"
|
| 249 |
-
prediction_df[this_col] = np.power(10., -prediction) * 1e6
|
| 250 |
-
prediction_cols.append(this_col)
|
| 251 |
-
this_col = f"{species}: predicted MIC (µg / mL)"
|
| 252 |
-
prediction_df[this_col] = np.power(10., -prediction) * 1e3 * prediction_df["mwt"]
|
| 253 |
-
prediction_cols.append(this_col)
|
| 254 |
-
|
| 255 |
-
for extra_metric in extra_metrics:
|
| 256 |
-
message = f"Calculating {extra_metric} for species: {species}"
|
| 257 |
-
print_err(message)
|
| 258 |
-
gr.Info(message, duration=10)
|
| 259 |
-
# this_modelbox._input_training_data = this_modelbox._input_training_data.remove_columns([this_modelbox._in_key])
|
| 260 |
-
this_col = f"{species}: {extra_metric}"
|
| 261 |
-
prediction_cols.append(this_col)
|
| 262 |
-
print(">>>", this_modelbox._input_training_data)
|
| 263 |
-
print(">>>", this_modelbox._input_training_data.format)
|
| 264 |
-
print(">>>", this_modelbox._in_key, this_modelbox._out_key)
|
| 265 |
-
this_extra = (
|
| 266 |
-
EXTRA_METRICS[extra_metric](
|
| 267 |
-
this_modelbox,
|
| 268 |
-
this_prediction_input,
|
| 269 |
-
)
|
| 270 |
-
.with_format("numpy")
|
| 271 |
-
)
|
| 272 |
-
prediction_df[this_col] = this_extra[this_extra.column_names[-1]]
|
| 273 |
-
other_cols = [col for col in prediction_df if col not in ['id', 'inchikey', 'smiles', "mwt", "clogp"] + [column] + prediction_cols]
|
| 274 |
-
|
| 275 |
-
return prediction_df[['id', 'inchikey'] + [column] + prediction_cols + other_cols + ['smiles', "mwt", "clogp"]]
|
| 276 |
|
| 277 |
def draw_one(
|
| 278 |
strings: Union[Iterable[str], str],
|
|
@@ -293,31 +283,35 @@ def draw_one(
|
|
| 293 |
legends=["\n".join(items) for items in zip(*_ids.values())],
|
| 294 |
)
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
-
|
|
|
|
| 298 |
df,
|
| 299 |
-
|
| 300 |
-
|
| 301 |
color: Optional[str] = None,
|
| 302 |
):
|
| 303 |
print_err(df.head())
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
y_title = f"Observed ({ycol})"
|
| 307 |
-
cols = ["id", "inchikey", "smiles", "mwt", "clogp", xcol, ycol]
|
| 308 |
-
color_title = color
|
| 309 |
if color is not None and color not in cols:
|
| 310 |
cols.append(color)
|
| 311 |
cols = list(set(cols))
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
x_title = f"Predicted log10[MIC(µM)]"
|
| 316 |
|
| 317 |
return gr.ScatterPlot(
|
| 318 |
value=df[cols],
|
| 319 |
-
x=
|
| 320 |
-
y=
|
| 321 |
color=color,
|
| 322 |
x_title=x_title,
|
| 323 |
y_title=y_title,
|
|
@@ -327,14 +321,32 @@ def plot_pred_vs_observed(
|
|
| 327 |
)
|
| 328 |
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
def download_table(
|
| 331 |
df: pd.DataFrame
|
| 332 |
) -> str:
|
| 333 |
df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
|
| 334 |
-
filename = f"
|
| 335 |
df.to_csv(filename, index=False)
|
| 336 |
return gr.DownloadButton(value=filename, visible=True)
|
| 337 |
|
|
|
|
| 338 |
with gr.Blocks() as demo:
|
| 339 |
|
| 340 |
with open(HEADER_FILE, 'r') as f:
|
|
@@ -379,7 +391,7 @@ with gr.Blocks() as demo:
|
|
| 379 |
]),
|
| 380 |
list(MODEL_REPOS)[0],
|
| 381 |
list(EXTRA_METRICS)[:2],
|
| 382 |
-
|
| 383 |
[
|
| 384 |
'\n'.join([
|
| 385 |
"C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
|
|
@@ -399,6 +411,7 @@ with gr.Blocks() as demo:
|
|
| 399 |
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
| 400 |
"CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
|
| 401 |
"C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
|
|
|
|
| 402 |
]),
|
| 403 |
list(MODEL_REPOS)[0],
|
| 404 |
list(EXTRA_METRICS)[:2],
|
|
@@ -420,10 +433,37 @@ with gr.Blocks() as demo:
|
|
| 420 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
|
| 421 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
|
| 422 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
]),
|
| 424 |
list(MODEL_REPOS)[0],
|
| 425 |
list(EXTRA_METRICS)[:2],
|
| 426 |
-
], #
|
| 427 |
|
| 428 |
],
|
| 429 |
example_labels=[
|
|
@@ -431,8 +471,9 @@ with gr.Blocks() as demo:
|
|
| 431 |
"Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
|
| 432 |
"Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
|
| 433 |
"Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
|
| 434 |
-
"Debio-1452, Debio-1452-NH3, Fabimycin",
|
| 435 |
-
|
|
|
|
| 436 |
],
|
| 437 |
inputs=[input_line, output_species_single, extra_metric],
|
| 438 |
cache_mode="eager",
|
|
@@ -476,7 +517,7 @@ with gr.Blocks() as demo:
|
|
| 476 |
outputs=download_single
|
| 477 |
)
|
| 478 |
|
| 479 |
-
with gr.Tab("Predict on structures from a file (max.
|
| 480 |
input_file = gr.File(
|
| 481 |
label="Upload a table of chemical compounds here",
|
| 482 |
file_types=[".xlsx", ".csv", ".tsv", ".txt"],
|
|
@@ -524,14 +565,36 @@ with gr.Blocks() as demo:
|
|
| 524 |
)
|
| 525 |
with gr.Row():
|
| 526 |
observed_col = gr.Dropdown(
|
| 527 |
-
label="Observed column (y-axis) for
|
| 528 |
choices=[],
|
| 529 |
value=None,
|
| 530 |
interactive=True,
|
| 531 |
visible=False,
|
| 532 |
)
|
| 533 |
color_col = gr.Dropdown(
|
| 534 |
-
label="Color for
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
choices=[],
|
| 536 |
value=None,
|
| 537 |
interactive=True,
|
|
@@ -544,38 +607,65 @@ with gr.Blocks() as demo:
|
|
| 544 |
file_examples = gr.Examples(
|
| 545 |
examples=[
|
| 546 |
[
|
| 547 |
-
"example-data/stokes2020-eco
|
| 548 |
"SMILES",
|
| 549 |
"Klebsiella pneumoniae",
|
| 550 |
"Mean_Inhibition",
|
| 551 |
"Klebsiella pneumoniae: Doubtscore",
|
| 552 |
-
list(EXTRA_METRICS)[:3]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
],
|
| 554 |
example_labels=[
|
| 555 |
-
"Stokes J. et al., Cell, 2020",
|
|
|
|
|
|
|
| 556 |
],
|
| 557 |
inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
|
| 558 |
cache_mode="eager",
|
| 559 |
)
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
|
| 568 |
file_examples.load_input_event.then(
|
| 569 |
-
|
| 570 |
-
inputs=[input_file],
|
| 571 |
-
outputs=[input_data, input_column],
|
| 572 |
)
|
| 573 |
input_file.upload(
|
| 574 |
-
|
| 575 |
-
inputs=[input_file],
|
| 576 |
-
outputs=[input_data, input_column]
|
| 577 |
)
|
| 578 |
-
go_button2.click(
|
| 579 |
predict_file,
|
| 580 |
inputs=[
|
| 581 |
input_data,
|
|
@@ -591,18 +681,17 @@ with gr.Blocks() as demo:
|
|
| 591 |
download_table,
|
| 592 |
inputs=input_data,
|
| 593 |
outputs=download
|
| 594 |
-
).then(
|
| 595 |
-
partial(get_dropdown_options, _type="number"),
|
| 596 |
-
inputs=[input_data],
|
| 597 |
-
outputs=[observed_col],
|
| 598 |
-
).then(
|
| 599 |
-
partial(get_dropdown_options, _type="number"),
|
| 600 |
-
inputs=[input_data],
|
| 601 |
-
outputs=[color_col],
|
| 602 |
).then(
|
| 603 |
lambda: gr.Button(visible=True),
|
| 604 |
-
outputs=[plot_button]
|
| 605 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
|
| 607 |
plot_button.click(
|
| 608 |
plot_pred_vs_observed,
|
|
@@ -612,7 +701,16 @@ with gr.Blocks() as demo:
|
|
| 612 |
observed_col,
|
| 613 |
color_col,
|
| 614 |
],
|
| 615 |
-
outputs=pred_vs_observed,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
)
|
| 617 |
|
| 618 |
if __name__ == "__main__":
|
|
|
|
| 23 |
import torch
|
| 24 |
|
| 25 |
CACHE = "./cache"
|
| 26 |
+
MAX_ROWS = 4000
|
| 27 |
HEADER_FILE = os.path.join("sources", "header.md")
|
| 28 |
MODEL_REPOS = {
|
| 29 |
"Klebsiella pneumoniae": "hf://scbirlab/spark-dv-fp-2503-kpn",
|
|
|
|
| 79 |
input_representation: str = 'smiles',
|
| 80 |
output_representation: Union[Iterable[str], str] = 'smiles'
|
| 81 |
):
|
| 82 |
+
output_representation = cast(output_representation, to=list)
|
| 83 |
+
for rep in output_representation:
|
| 84 |
+
message = f"Converting from {input_representation} to {rep}..."
|
| 85 |
+
gr.Info(message, duration=10)
|
| 86 |
|
| 87 |
df = pd.DataFrame({
|
| 88 |
input_representation: _clean_split_input(strings),
|
|
|
|
| 96 |
)
|
| 97 |
|
| 98 |
|
| 99 |
+
def _prediction_loop(
|
| 100 |
+
df: pd.DataFrame,
|
|
|
|
| 101 |
predict: Union[Iterable[str], str] = 'smiles',
|
| 102 |
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
| 103 |
+
) -> pd.DataFrame:
|
| 104 |
+
species_to_predict = cast(predict, to=list)
|
| 105 |
+
prediction_cols = []
|
| 106 |
if extra_metrics is None:
|
| 107 |
extra_metrics = []
|
| 108 |
else:
|
| 109 |
extra_metrics = cast(extra_metrics, to=list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
for species in species_to_predict:
|
| 111 |
message = f"Predicting for species: {species}"
|
| 112 |
print_err(message)
|
|
|
|
| 115 |
this_features = this_modelbox._input_cols
|
| 116 |
this_labels = this_modelbox._label_cols
|
| 117 |
this_prediction_input = (
|
| 118 |
+
df
|
| 119 |
.rename(columns={
|
| 120 |
"smiles": this_features[0],
|
| 121 |
})
|
|
|
|
| 131 |
).with_format("numpy")["__prediction__"].flatten()
|
| 132 |
print(prediction)
|
| 133 |
this_col = f"{species}: predicted MIC (µM)"
|
| 134 |
+
df[this_col] = np.power(10., -prediction) * 1e6
|
| 135 |
prediction_cols.append(this_col)
|
| 136 |
this_col = f"{species}: predicted MIC (µg / mL)"
|
| 137 |
+
df[this_col] = np.power(10., -prediction) * 1e3 * df["mwt"]
|
| 138 |
prediction_cols.append(this_col)
|
| 139 |
|
| 140 |
for extra_metric in extra_metrics:
|
|
|
|
| 154 |
)
|
| 155 |
.with_format("numpy")
|
| 156 |
)
|
| 157 |
+
df[this_col] = this_extra[this_extra.column_names[-1]]
|
| 158 |
+
|
| 159 |
+
return prediction_cols, df
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def predict_one(
|
| 163 |
+
strings: str,
|
| 164 |
+
input_representation: str = 'smiles',
|
| 165 |
+
predict: Union[Iterable[str], str] = 'smiles',
|
| 166 |
+
extra_metrics: Optional[Union[Iterable[str], str]] = None
|
| 167 |
+
):
|
| 168 |
+
prediction_df = convert_one(
|
| 169 |
+
strings=strings,
|
| 170 |
+
input_representation=input_representation,
|
| 171 |
+
output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
|
| 172 |
+
)
|
| 173 |
+
prediction_cols, prediction_df = _prediction_loop(
|
| 174 |
+
prediction_df,
|
| 175 |
+
predict=predict,
|
| 176 |
+
extra_metrics=extra_metrics,
|
| 177 |
+
)
|
| 178 |
return gr.DataFrame(
|
| 179 |
+
prediction_df[
|
| 180 |
+
['id', 'pubchem_name', 'pubchem_id']
|
| 181 |
+
+ prediction_cols
|
| 182 |
+
+ ['smiles', 'inchikey', "mwt", "clogp"]
|
| 183 |
+
],
|
| 184 |
visible=True
|
| 185 |
)
|
| 186 |
|
|
|
|
| 231 |
else:
|
| 232 |
extra_metrics = cast(extra_metrics, to=list)
|
| 233 |
|
| 234 |
+
if df.shape[0] > MAX_ROWS:
|
| 235 |
+
message = f"Truncating input to {MAX_ROWS} rows"
|
| 236 |
+
print_err(message)
|
| 237 |
+
gr.Info(message, duration=15)
|
| 238 |
+
df = df.iloc[:MAX_ROWS]
|
| 239 |
+
|
| 240 |
prediction_df = convert_file(
|
| 241 |
df,
|
| 242 |
column=column,
|
| 243 |
input_representation=input_representation,
|
| 244 |
output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
|
| 245 |
)
|
| 246 |
+
prediction_cols, prediction_df = _prediction_loop(
|
| 247 |
+
prediction_df,
|
| 248 |
+
predict=predict,
|
| 249 |
+
extra_metrics=extra_metrics,
|
| 250 |
+
)
|
| 251 |
+
main_cols = set(
|
| 252 |
+
['id', 'inchikey', 'smiles', "mwt", "clogp"]
|
| 253 |
+
+ [column]
|
| 254 |
+
+ prediction_cols
|
| 255 |
+
)
|
| 256 |
+
other_cols = [
|
| 257 |
+
col for col in prediction_df
|
| 258 |
+
if col not in main_cols
|
| 259 |
+
]
|
| 260 |
+
return prediction_df[
|
| 261 |
+
['id', 'inchikey']
|
| 262 |
+
+ [column]
|
| 263 |
+
+ prediction_cols + other_cols
|
| 264 |
+
+ ['smiles', "mwt", "clogp"]
|
| 265 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
def draw_one(
|
| 268 |
strings: Union[Iterable[str], str],
|
|
|
|
| 283 |
legends=["\n".join(items) for items in zip(*_ids.values())],
|
| 284 |
)
|
| 285 |
|
| 286 |
+
def log10_if_all_positive(df, col):
|
| 287 |
+
if np.all(df[col] > 0.):
|
| 288 |
+
df[col] = np.log10(df[col])
|
| 289 |
+
title = f"log10[ {col} ]"
|
| 290 |
+
else:
|
| 291 |
+
title = col
|
| 292 |
+
return title, df
|
| 293 |
|
| 294 |
+
|
| 295 |
+
def plot_x_vs_y(
|
| 296 |
df,
|
| 297 |
+
x: str,
|
| 298 |
+
y: str,
|
| 299 |
color: Optional[str] = None,
|
| 300 |
):
|
| 301 |
print_err(df.head())
|
| 302 |
+
y_title = y
|
| 303 |
+
cols = ["id", "inchikey", "smiles", "mwt", "clogp", x, y]
|
|
|
|
|
|
|
|
|
|
| 304 |
if color is not None and color not in cols:
|
| 305 |
cols.append(color)
|
| 306 |
cols = list(set(cols))
|
| 307 |
+
x_title, df = log10_if_all_positive(df, x)
|
| 308 |
+
y_title, df = log10_if_all_positive(df, y)
|
| 309 |
+
color_title, df = log10_if_all_positive(df, color)
|
|
|
|
| 310 |
|
| 311 |
return gr.ScatterPlot(
|
| 312 |
value=df[cols],
|
| 313 |
+
x=x,
|
| 314 |
+
y=y,
|
| 315 |
color=color,
|
| 316 |
x_title=x_title,
|
| 317 |
y_title=y_title,
|
|
|
|
| 321 |
)
|
| 322 |
|
| 323 |
|
| 324 |
+
def plot_pred_vs_observed(
|
| 325 |
+
df,
|
| 326 |
+
species: str,
|
| 327 |
+
observed: str,
|
| 328 |
+
color: Optional[str] = None,
|
| 329 |
+
):
|
| 330 |
+
print_err(df.head())
|
| 331 |
+
xcol = f"{species}: predicted MIC (µM)"
|
| 332 |
+
ycol = observed
|
| 333 |
+
return plot_x_vs_y(
|
| 334 |
+
df,
|
| 335 |
+
x=xcol,
|
| 336 |
+
y=ycol,
|
| 337 |
+
color=color,
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
def download_table(
|
| 342 |
df: pd.DataFrame
|
| 343 |
) -> str:
|
| 344 |
df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
|
| 345 |
+
filename = f"predicted-{df_hash}.csv"
|
| 346 |
df.to_csv(filename, index=False)
|
| 347 |
return gr.DownloadButton(value=filename, visible=True)
|
| 348 |
|
| 349 |
+
|
| 350 |
with gr.Blocks() as demo:
|
| 351 |
|
| 352 |
with open(HEADER_FILE, 'r') as f:
|
|
|
|
| 391 |
]),
|
| 392 |
list(MODEL_REPOS)[0],
|
| 393 |
list(EXTRA_METRICS)[:2],
|
| 394 |
+
], # cipro, ceftriaxone, cefiderocol, linezolid, gepotidacin
|
| 395 |
[
|
| 396 |
'\n'.join([
|
| 397 |
"C[C@H]1[C@H]([C@H](C[C@@H](O1)O[C@H]2C[C@@](CC3=C2C(=C4C(=C3O)C(=O)C5=C(C4=O)C(=CC=C5)OC)O)(C(=O)CO)O)N)O",
|
|
|
|
| 411 |
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
| 412 |
"CC1=CC(=NO1)NS(=O)(=O)C2=CC=C(C=C2)N",
|
| 413 |
"C1[C@@H]([C@H]([C@@H]([C@H]([C@@H]1NC(=O)[C@H](CCN)O)O[C@@H]2[C@@H]([C@H]([C@@H]([C@H](O2)CO)O)N)O)O)O[C@@H]3[C@@H]([C@H]([C@@H]([C@H](O3)CN)O)O)O)N\nC1=CN=CC=C1C(=O)NN",
|
| 414 |
+
"C1=CN=CC=C1C(=O)NN ",
|
| 415 |
]),
|
| 416 |
list(MODEL_REPOS)[0],
|
| 417 |
list(EXTRA_METRICS)[:2],
|
|
|
|
| 433 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)CC4)N=C3",
|
| 434 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@@H](C4)N)N=C3",
|
| 435 |
"CC1=C(OC2=CC=CC=C12)CN(C)C(=O)/C=C/C3=CC4=C(NC(=O)[C@H](CC4)[NH3+])N=C3.[Cl-]",
|
| 436 |
+
"C1=C(C(=O)NC(=O)N1)F",
|
| 437 |
+
"CCCCCCNC(=O)N1C=C(C(=O)NC1=O)F",
|
| 438 |
+
"C[C@@H]1OC[C@@H]2[C@@H](O1)[C@@H]([C@H]([C@@H](O2)O[C@H]3[C@H]4COC(=O)[C@@H]4[C@@H](C5=CC6=C(C=C35)OCO6)C7=CC(=C(C(=C7)OC)O)OC)O)O",
|
| 439 |
+
]),
|
| 440 |
+
list(MODEL_REPOS)[0],
|
| 441 |
+
list(EXTRA_METRICS)[:2],
|
| 442 |
+
], # Debio1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide
|
| 443 |
+
[
|
| 444 |
+
'\n'.join([
|
| 445 |
+
"COC1=CC(=CC(=C1OC)OC)CC2=CN=C(N=C2N)N",
|
| 446 |
+
"CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
|
| 447 |
+
"C1=CC(=CC=C1CCC2=CNC3=C2C(=O)NC(=N3)N)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
|
| 448 |
+
"CC1=C(C2=C(C=C1)N=C(NC2=O)N)SC3=CC=NC=C3",
|
| 449 |
+
"CN(CC1=CN=C2C(=N1)C(=NC(=N2)N)N)C3=CC=C(C=C3)C(=O)N[C@@H](CCC(=O)O)C(=O)O",
|
| 450 |
+
"CC1=NC2=C(C=C(C=C2)CN(C)C3=CC=C(S3)C(=O)N[C@@H](CCC(=O)O)C(=O)O)C(=O)N1",
|
| 451 |
+
]),
|
| 452 |
+
list(MODEL_REPOS)[0],
|
| 453 |
+
list(EXTRA_METRICS)[:2],
|
| 454 |
+
], # Trimethoprim, SCH79797, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed
|
| 455 |
+
[
|
| 456 |
+
'\n'.join([
|
| 457 |
+
"C[C@H]([C@@H](C(=O)NO)NC(=O)C1=CC=C(C=C1)C#CC2=CC=C(C=C2)CN3CCOCC3)O",
|
| 458 |
+
"CC(C)C1=CC=C(C=C1)CN2C=CC3=C2C=CC4=C3C(=NC(=N4)NC5CC5)N",
|
| 459 |
+
"C1=CC=C(C=C1)CNC2=NC(=NC3=CC=CC=C32)NCC4=CC=CC=C4",
|
| 460 |
+
"CC(C)(C)C1=CC=C(C=C1)C(=O)NC(=S)NC2=CC=C(C=C2)NC(=O)CCCCN(C)C",
|
| 461 |
+
"CCC1=C(C(=NC(=N1)N)N)C2=CC=C(C=C2)Cl",
|
| 462 |
+
"C1=CC(=CC=C1C(=O)N[C@@H](CCC(=O)O)C(=O)O)NCC2=CN=C3C(=N2)C(=NC(=N3)N)N",
|
| 463 |
]),
|
| 464 |
list(MODEL_REPOS)[0],
|
| 465 |
list(EXTRA_METRICS)[:2],
|
| 466 |
+
], # CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin
|
| 467 |
|
| 468 |
],
|
| 469 |
example_labels=[
|
|
|
|
| 471 |
"Doxorubicin, Ampicillin, Amoxicillin, Meropenem, Tetracycline, Anhydrotetracycline",
|
| 472 |
"Halicin, Abaucin, Trimethoprim, Sulfamethoxazole, Amikacin, Isoniazid",
|
| 473 |
"Murepavadin, Vancomycin, Zosurabalpin, Plazomicin, Gentamicin, Rifampicin",
|
| 474 |
+
"Debio-1452, Debio-1452-NH3, Fabimycin, 5-FU, Carmofur, Etoposide",
|
| 475 |
+
"Trimethoprim, Pemetrexed, Nolatrexed, Methotrexate, Raltitrexed",
|
| 476 |
+
"CHIR-090, SCH79797, DBeQ, Tenovin-6, Pyrimethamine, Aminopterin"
|
| 477 |
],
|
| 478 |
inputs=[input_line, output_species_single, extra_metric],
|
| 479 |
cache_mode="eager",
|
|
|
|
| 517 |
outputs=download_single
|
| 518 |
)
|
| 519 |
|
| 520 |
+
with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, single species)"):
|
| 521 |
input_file = gr.File(
|
| 522 |
label="Upload a table of chemical compounds here",
|
| 523 |
file_types=[".xlsx", ".csv", ".tsv", ".txt"],
|
|
|
|
| 565 |
)
|
| 566 |
with gr.Row():
|
| 567 |
observed_col = gr.Dropdown(
|
| 568 |
+
label="Observed column (y-axis) for left plot",
|
| 569 |
choices=[],
|
| 570 |
value=None,
|
| 571 |
interactive=True,
|
| 572 |
visible=False,
|
| 573 |
)
|
| 574 |
color_col = gr.Dropdown(
|
| 575 |
+
label="Color for left plot",
|
| 576 |
+
choices=[],
|
| 577 |
+
value=None,
|
| 578 |
+
interactive=True,
|
| 579 |
+
visible=False,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
any_x_col = gr.Dropdown(
|
| 583 |
+
label="x-axis for right plot",
|
| 584 |
+
choices=[],
|
| 585 |
+
value=None,
|
| 586 |
+
interactive=True,
|
| 587 |
+
visible=False,
|
| 588 |
+
)
|
| 589 |
+
any_y_col = gr.Dropdown(
|
| 590 |
+
label="y-axis for right plot",
|
| 591 |
+
choices=[],
|
| 592 |
+
value=None,
|
| 593 |
+
interactive=True,
|
| 594 |
+
visible=False,
|
| 595 |
+
)
|
| 596 |
+
any_color_col = gr.Dropdown(
|
| 597 |
+
label="Color for right plot",
|
| 598 |
choices=[],
|
| 599 |
value=None,
|
| 600 |
interactive=True,
|
|
|
|
| 607 |
file_examples = gr.Examples(
|
| 608 |
examples=[
|
| 609 |
[
|
| 610 |
+
"example-data/stokes2020-eco.csv",
|
| 611 |
"SMILES",
|
| 612 |
"Klebsiella pneumoniae",
|
| 613 |
"Mean_Inhibition",
|
| 614 |
"Klebsiella pneumoniae: Doubtscore",
|
| 615 |
+
list(EXTRA_METRICS)[:3],
|
| 616 |
+
],
|
| 617 |
+
[
|
| 618 |
+
"example-data/liu23-abau.csv",
|
| 619 |
+
"SMILES",
|
| 620 |
+
"Klebsiella pneumoniae",
|
| 621 |
+
"Mean",
|
| 622 |
+
"Klebsiella pneumoniae: Doubtscore",
|
| 623 |
+
list(EXTRA_METRICS)[:3],
|
| 624 |
+
],
|
| 625 |
+
[
|
| 626 |
+
"example-data/wong24-sau-tox-5000.csv",
|
| 627 |
+
"SMILES",
|
| 628 |
+
"Klebsiella pneumoniae",
|
| 629 |
+
"Mean",
|
| 630 |
+
"Klebsiella pneumoniae: Doubtscore",
|
| 631 |
+
list(EXTRA_METRICS)[:3],
|
| 632 |
+
],
|
| 633 |
],
|
| 634 |
example_labels=[
|
| 635 |
+
"E. coli training data from Stokes J. et al., Cell, 2020",
|
| 636 |
+
"A. baumannii training data from Liu, 2023",
|
| 637 |
+
"S. aureus and toxicity training data from Wong, 2024",
|
| 638 |
],
|
| 639 |
inputs=[input_file, input_column, output_species, observed_col, color_col, extra_metric_file],
|
| 640 |
cache_mode="eager",
|
| 641 |
)
|
| 642 |
+
with gr.Row():
|
| 643 |
+
pred_vs_observed = gr.ScatterPlot(
|
| 644 |
+
label="Prediction vs observed",
|
| 645 |
+
x_title="Predicted MIC (µM)",
|
| 646 |
+
y_title="Observed",
|
| 647 |
+
visible=False,
|
| 648 |
+
height=600,
|
| 649 |
+
)
|
| 650 |
+
plot_any_vs_any = gr.ScatterPlot(
|
| 651 |
+
label="Any vs any",
|
| 652 |
+
visible=False,
|
| 653 |
+
height=600,
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
load_data_action = {
|
| 657 |
+
"fn": load_input_data,
|
| 658 |
+
"inputs": [input_file],
|
| 659 |
+
"outputs": [input_data, input_column]
|
| 660 |
+
}
|
| 661 |
|
| 662 |
file_examples.load_input_event.then(
|
| 663 |
+
**load_data_action,
|
|
|
|
|
|
|
| 664 |
)
|
| 665 |
input_file.upload(
|
| 666 |
+
**load_data_action,
|
|
|
|
|
|
|
| 667 |
)
|
| 668 |
+
go2_click_event = go_button2.click(
|
| 669 |
predict_file,
|
| 670 |
inputs=[
|
| 671 |
input_data,
|
|
|
|
| 681 |
download_table,
|
| 682 |
inputs=input_data,
|
| 683 |
outputs=download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 684 |
).then(
|
| 685 |
lambda: gr.Button(visible=True),
|
| 686 |
+
outputs=[plot_button]
|
| 687 |
)
|
| 688 |
+
|
| 689 |
+
for dropdown in [observed_col, color_col, any_color_col, any_x_col, any_y_col]:
|
| 690 |
+
go2_click_event.then(
|
| 691 |
+
partial(get_dropdown_options, _type="number"),
|
| 692 |
+
inputs=[input_data],
|
| 693 |
+
outputs=[dropdown],
|
| 694 |
+
)
|
| 695 |
|
| 696 |
plot_button.click(
|
| 697 |
plot_pred_vs_observed,
|
|
|
|
| 701 |
observed_col,
|
| 702 |
color_col,
|
| 703 |
],
|
| 704 |
+
outputs=[pred_vs_observed],
|
| 705 |
+
).then(
|
| 706 |
+
plot_x_vs_y,
|
| 707 |
+
inputs=[
|
| 708 |
+
input_data,
|
| 709 |
+
any_x_col,
|
| 710 |
+
any_y_col,
|
| 711 |
+
any_color_col,
|
| 712 |
+
],
|
| 713 |
+
outputs=[plot_any_vs_any],
|
| 714 |
)
|
| 715 |
|
| 716 |
if __name__ == "__main__":
|
example-data/liu23-abau.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
example-data/{stokes2020-eco-1000.csv → stokes2020-eco.csv}
RENAMED
|
The diff for this file is too large to render.
See raw diff
|
|
|
example-data/wong24-sau-tox-5000.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|