Spaces:
Runtime error
Runtime error
| """Gradio demo for schemist.""" | |
| from typing import Iterable, List, Optional, Union | |
| import csv | |
| from io import TextIOWrapper | |
| import itertools | |
| import json | |
| import os | |
| import sys | |
| csv.field_size_limit(sys.maxsize) | |
| from carabiner import cast, print_err | |
| from carabiner.pd import read_table | |
| from duvida.autoclass import AutoModelBox | |
| import gradio as gr | |
| import nemony as nm | |
| import numpy as np | |
| import pandas as pd | |
| from rdkit.Chem import Draw, Mol | |
| from schemist.converting import ( | |
| _FROM_FUNCTIONS, | |
| convert_string_representation, | |
| _x2mol, | |
| ) | |
| from schemist.tables import converter | |
| import torch | |
| from duvida.stateless.config import config | |
THEME = gr.themes.Default()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CACHE = "./cache"
MAX_ROWS = 1000
BATCH_SIZE = 16
HEADER_FILE = os.path.join("sources", "header.md")

# Mapping of species name -> pretrained model identifier, read at import time.
with open("repos.json", "r", encoding="utf-8") as f:
    MODEL_REPOS = json.load(f)

# One model box per species, all loaded eagerly at import time.
MODELBOXES = {
    key: AutoModelBox.from_pretrained(val, cache_dir=os.path.join(CACHE, "duvida"))
    for key, val in MODEL_REPOS.items()
}
# Move every model to the selected device. A plain loop is used because this
# is done only for its side effect (the return values were never used).
for _modelbox in MODELBOXES.values():
    _modelbox.to(DEVICE)

# Optional per-compound metrics. Each entry is called as
# ``fn(modelbox, candidates)``; the caller converts the result with
# ``.with_format("numpy")`` and reads its last column. Several metrics are
# log10-transformed here before being returned.
EXTRA_METRICS = {
    "log10(variance)": lambda modelbox, candidates: (
        modelbox.prediction_variance(
            candidates=candidates,
            batch_size=BATCH_SIZE,
            cache=CACHE,
        )
        .map(lambda x: {modelbox._variance_key: torch.log10(x[modelbox._variance_key])})
    ),
    "Tanimoto nearest neighbor to training data": lambda modelbox, candidates: (
        modelbox.tanimoto_nn(
            candidates=candidates,
            batch_size=BATCH_SIZE,
        )
    ),
    "Doubtscore": lambda modelbox, candidates: (
        modelbox.doubtscore(
            candidates=candidates,
            cache=CACHE,
            batch_size=BATCH_SIZE,
        )
        .map(lambda x: {"doubtscore": torch.log10(x["doubtscore"])})
    ),
    "Information sensitivity (approx.)": lambda modelbox, candidates: (
        modelbox.information_sensitivity(
            candidates=candidates,
            batch_size=BATCH_SIZE,
            optimality_approximation=True,
            approximator="squared_jacobian",
            cache=CACHE,
        )
        .map(lambda x: {"information sensitivity": torch.log10(x["information sensitivity"])})
    ),
}

# Example inputs shown in the UI.
with open(os.path.join("example-data", "examples.json"), "r", encoding="utf-8") as f:
    EXAMPLES = json.load(f)
def get_dropdown_options(df, _type = str):
    """Build a dropdown that lists columns of ``df`` of one kind.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose column labels are offered as choices.
    _type : optional
        ``str`` (the default) lists non-numeric columns; any other value
        lists numeric columns.

    Returns
    -------
    gradio.Dropdown
        Interactive, visible dropdown that accepts custom values. Its
        default value is the first matching column, or ``""`` if there is
        none.
    """
    if _type is str:
        cols = list(df.select_dtypes(exclude=[np.number]))
    else:
        cols = list(df.select_dtypes([np.number]))
    # Column labels of None cannot be meaningfully selected.
    non_none = [col for col in cols if col is not None]
    # Check the filtered list itself: checking ``cols`` raised IndexError
    # when every matching column label was None.
    default_value = non_none[0] if non_none else ""
    print_err(f"Dropdown default value is {default_value}")
    return gr.Dropdown(
        choices=non_none,
        interactive=True,
        value=default_value,
        visible=True,
        allow_custom_value=True,
    )
def load_input_data(file: Union[TextIOWrapper, str], return_pd: bool = False):
    """Load up to ``MAX_ROWS`` rows of a tabular file for display.

    Parameters
    ----------
    file : str or file-like
        Path to the table, or an object exposing the path as ``.name``.
    return_pd : bool
        If True, also return the loaded DataFrame.

    Returns
    -------
    tuple
        ``(gr.Dataframe, gr.Dropdown)`` by default, or
        ``((pd.DataFrame, gr.Dataframe), gr.Dropdown)`` when ``return_pd``
        is True. The dropdown lists the table's non-numeric columns.
    """
    path = file if isinstance(file, str) else file.name
    print_err(f"Loading {path}")
    df = read_table(path, nrows=MAX_ROWS)
    print_err(df.head())
    table = gr.Dataframe(value=df, visible=True)
    column_dropdown = get_dropdown_options(df, str)
    if return_pd:
        return (df, table), column_dropdown
    return table, column_dropdown
| def _clean_split_input(strings: str) -> List[str]: | |
| return [ | |
| s2.split(":")[-1].strip() | |
| for s in strings.split("\n") | |
| for s2 in s.split(",") | |
| ] | |
def _convert_input(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles'
) -> dict:
    """Convert pasted identifiers and return the results as strings.

    The input is split and cleaned with ``_clean_split_input``, then passed
    to ``convert_string_representation``.

    Returns
    -------
    dict
        Same keys as the converter's result; each value is coerced to a
        list and every element to ``str``. (The previous ``List[str]``
        annotation did not match this return value.)
    """
    cleaned = _clean_split_input(strings)
    converted = convert_string_representation(
        strings=cleaned,
        input_representation=input_representation,
        output_representation=output_representation,
    )
    return {
        key: [str(item) for item in cast(val, to=list)]
        for key, val in converted.items()
    }
def convert_one(
    strings: str,
    input_representation: str = 'smiles',
    output_representation: Union[Iterable[str], str] = 'smiles',
):
    """Convert pasted identifiers into a table of other representations.

    Shows one notification per requested output representation, then
    builds a one-column DataFrame (column named after
    ``input_representation``) from the cleaned input and delegates to
    ``convert_file``.

    Returns
    -------
    pandas.DataFrame
        Whatever ``convert_file`` returns for that table.
    """
    targets = cast(output_representation, to=list)
    for target in targets:
        gr.Info(f"Converting from {input_representation} to {target}...", duration=3)
    table = pd.DataFrame({input_representation: _clean_split_input(strings)})
    return convert_file(
        df=table,
        column=input_representation,
        input_representation=input_representation,
        output_representation=targets,
    )
def _prediction_loop(
    df: pd.DataFrame,
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None
) -> tuple:
    """Add predicted-MIC and optional extra-metric columns for each species.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with at least ``"smiles"`` and ``"mwt"`` columns. It is not
        modified; a copy is extended and returned.
    predict : str or iterable of str
        Species names; each must be a key of ``MODELBOXES``.
    extra_metrics : str or iterable of str, optional
        Keys of ``EXTRA_METRICS`` to compute for every species.

    Returns
    -------
    tuple
        ``(prediction_cols, df)``: the names of the added columns in order,
        and the extended copy of ``df``.
    """
    species_to_predict = cast(predict, to=list)
    extra_metrics = [] if extra_metrics is None else cast(extra_metrics, to=list)
    # Work on a copy so the caller's DataFrame is not modified in place.
    df = df.copy()
    prediction_cols = []
    for species in species_to_predict:
        message = f"Predicting for species: {species}"
        print_err(message)
        gr.Info(message, duration=3)
        this_modelbox = MODELBOXES[species]
        this_features = this_modelbox._input_cols
        this_labels = this_modelbox._label_cols
        # The "smiles" column is renamed to the model's first input column,
        # and the model's label columns are added as NaN placeholders.
        this_prediction_input = (
            df
            .rename(columns={
                "smiles": this_features[0],
            })
            .assign(**{label: np.nan for label in this_labels})
        )
        prediction = this_modelbox.predict(
            data=this_prediction_input,
            features=this_features,
            labels=this_labels,
            aggregator="mean",
            cache=CACHE,
        ).with_format("numpy")["__prediction__"].flatten()
        # The model output is treated as -log10(MIC in mol/L).
        molar_mic = np.power(10., -prediction)
        this_col = f"{species}: predicted MIC (µM)"
        df[this_col] = molar_mic * 1e6
        prediction_cols.append(this_col)
        # mol/L * g/mol = g/L, and g/L * 1e3 = µg/mL.
        this_col = f"{species}: predicted MIC (µg / mL)"
        df[this_col] = molar_mic * 1e3 * df["mwt"]
        prediction_cols.append(this_col)
        for extra_metric in extra_metrics:
            message = f"Calculating {extra_metric} for species: {species}"
            print_err(message)
            gr.Info(message, duration=10)
            this_col = f"{species}: {extra_metric}"
            prediction_cols.append(this_col)
            this_extra = (
                EXTRA_METRICS[extra_metric](
                    this_modelbox,
                    this_prediction_input,
                )
                .with_format("numpy")
            )
            # The metric value is read from the last column of the result.
            df[this_col] = this_extra[this_extra.column_names[-1]]
    return prediction_cols, df
def predict_one(
    strings: str,
    input_representation: str = 'smiles',
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None,
    return_pd: bool = False
):
    """Convert pasted identifiers, predict for each species, and build a table.

    Returns
    -------
    gradio.DataFrame, or (pandas.DataFrame, gradio.DataFrame)
        The display component; when ``return_pd`` is True, the displayed
        DataFrame is also returned first.
    """
    prediction_df = convert_one(
        strings=strings,
        input_representation=input_representation,
        output_representation=['id', 'pubchem_name', 'pubchem_id', 'smiles', 'inchikey', "mwt", "clogp"],
    )
    prediction_cols, prediction_df = _prediction_loop(
        prediction_df,
        predict=predict,
        extra_metrics=extra_metrics,
    )
    df = prediction_df[
        ['id', 'pubchem_name', 'pubchem_id']
        + prediction_cols
        + ['smiles', 'inchikey', "mwt", "clogp"]
    ]
    gradio_opts = {
        "label": "Predictions",
        "value": df,
        "pinned_columns": 3,
        "visible": True,
        "wrap": True,
        # One width per *displayed* column. This was previously computed
        # from ``prediction_df``, which can have more columns than ``df``.
        "column_widths": [120] * 3 + [250] * (df.shape[1] - 3),
    }
    if return_pd:
        return df, gr.DataFrame(**gradio_opts)
    return gr.DataFrame(**gradio_opts)
def convert_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    output_representation: Union[str, Iterable[str]] = 'smiles'
):
    """Convert one column of ``df`` into the requested representations.

    The requested representation columns come first in the result, followed
    by all remaining columns in their existing order. The total error count
    reported by the converter is logged and shown to the user.

    Returns
    -------
    pandas.DataFrame
        The converted, reordered table.
    """
    targets = cast(output_representation, to=list)
    start_message = f"Converting from {input_representation} to {', '.join(targets)}..."
    gr.Info(start_message, duration=5)
    print_err(start_message)
    print_err(df.head())
    errors, converted = converter(
        df=df,
        column=column,
        input_representation=input_representation,
        output_representation=targets,
    )
    remaining = [col for col in converted if col not in targets]
    converted = converted[targets + remaining]
    n_errors = sum(errors.values())
    done_message = (
        f"Converted {converted.shape[0]} molecules from "
        f"{input_representation} to {targets} "
        f"with {n_errors} errors!"
    )
    print_err(done_message)
    gr.Info(done_message, duration=5)
    return converted
def predict_file(
    df: pd.DataFrame,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    predict: str = 'smiles',
    predict2: Optional[str] = None,
    extra_metrics: Optional[Union[Iterable[str], str]] = None,
    return_pd: bool = False
):
    """Predict for up to two species on a table of compounds.

    The table is truncated to ``MAX_ROWS`` rows, and ``column`` is converted
    to id/smiles/inchikey/mwt/clogp. Predicted columns are then added, and
    the columns are reordered as: id, inchikey, ``column``, predictions,
    other input columns (in their original order), smiles, mwt, clogp.

    Returns
    -------
    tuple
        ``(table, plot_dropdowns)``, where ``table`` is a ``gr.Dataframe``
        (or ``(pandas.DataFrame, gr.Dataframe)`` when ``return_pd`` is
        True) and ``plot_dropdowns`` is a 5-tuple holding the same
        numeric-column dropdown.
    """
    # Build a new list so a list passed by the caller is never mutated.
    species = list(cast(predict, to=list))
    # Add the second species only if it is valid and not already requested;
    # otherwise the same model would be run twice.
    if predict2 is not None and predict2 in MODELBOXES and predict2 not in species:
        species.append(predict2)
    extra_metrics = [] if extra_metrics is None else cast(extra_metrics, to=list)
    if df.shape[0] > MAX_ROWS:
        message = f"Truncating input to {MAX_ROWS} rows"
        print_err(message)
        gr.Info(message, duration=15)
        df = df.iloc[:MAX_ROWS]
    prediction_df = convert_file(
        df,
        column=column,
        input_representation=input_representation,
        output_representation=["id", "smiles", "inchikey", "mwt", "clogp"],
    )
    prediction_cols, prediction_df = _prediction_loop(
        prediction_df,
        predict=species,
        extra_metrics=extra_metrics,
    )
    left_cols = ['id', 'inchikey']
    end_cols = ["smiles", "mwt", "clogp"]
    main_cols = set(left_cols + end_cols + [column] + prediction_cols)
    # Keep the remaining columns in table order. Iterating a set here made
    # the order vary between processes (string hashing is randomized).
    other_cols = [col for col in prediction_df if col not in main_cols]
    return_cols = left_cols + [column] + prediction_cols + other_cols + end_cols
    # Drop repeats (e.g. when ``column`` is one of the fixed columns) while
    # keeping the first occurrence of each.
    prediction_df = prediction_df[list(dict.fromkeys(return_cols))]
    plot_dropdown = get_dropdown_options(prediction_df, _type="number")
    # The same update is sent to each of the five plot dropdowns.
    plot_dropdowns = (plot_dropdown,) * 5
    gradio_opts = {
        "label": "Predictions",
        "value": prediction_df,
        "pinned_columns": 3,
        "visible": True,
        "wrap": True,
        "column_widths": [120] * 3 + [250] * (prediction_df.shape[1] - 3),
    }
    table = gr.Dataframe(**gradio_opts)
    if return_pd:
        return (prediction_df, table), plot_dropdowns
    return table, plot_dropdowns
def draw_one(
    df,
    smiles_col: str = "smiles",
    legends: Optional[Union[str, Iterable[str]]] = None
):
    """Render the structures in ``df`` as a grid image with text legends.

    Parameters
    ----------
    df : pandas.DataFrame
        Table of compounds.
    smiles_col : str
        Column holding SMILES strings.
    legends : str or iterable of str, optional
        Column(s) whose values label each structure, one line per column.
        Defaults to ``["inchikey", "id", "pubchem_name"]``. Columns missing
        from ``df`` are skipped.

    Returns
    -------
    The image produced by ``rdkit.Chem.Draw.MolsToGridImage``.
    """
    if legends is None:
        legends = ["inchikey", "id", "pubchem_name"]
    else:
        # Use the caller's legend columns; previously they were replaced
        # by an empty list.
        legends = cast(legends, to=list)
    legends = [col for col in legends if col in df]
    gr.Info(f"Drawing {df.shape[0]} molecules...", duration=2)
    label_values = [df[col].tolist() for col in legends]
    mols = cast(_x2mol(df[smiles_col], "smiles"), to=list)
    if isinstance(mols, Mol):
        mols = [mols]
    if label_values:
        # One legend per molecule. None becomes an empty line, and str()
        # handles non-string cells, which "\n".join would reject.
        legend_text = [
            "\n".join("" if value is None else str(value) for value in row)
            for row in zip(*label_values)
        ]
    else:
        legend_text = [""] * len(mols)
    return Draw.MolsToGridImage(
        mols,
        # At least one column per row, so an empty input does not give 0.
        molsPerRow=max(1, min(5, len(mols))),
        subImgSize=(600, 600),
        legends=legend_text,
    )
def log10_if_all_positive(df, col):
    """Log10-transform a column when every value in it is strictly positive.

    Parameters
    ----------
    df : pandas.DataFrame
        Source table; it is never modified.
    col : str or None
        Column to consider. ``None`` or ``""`` means "no column", and ``df``
        is returned unchanged.

    Returns
    -------
    tuple
        ``(title, df)``. If the column was transformed, ``title`` is
        ``"log10[ {col} ]"`` and ``df`` is a new copy with that column
        replaced. Otherwise ``title`` is ``col`` and ``df`` is the input.
        Non-numeric columns, and columns with any value <= 0 or NaN, are
        left untouched.
    """
    if col is None or col == "":
        return col, df
    values = df[col]
    try:
        all_positive = bool(np.all(values > 0.))
    except TypeError:
        # e.g. a text column chosen via a custom dropdown value.
        all_positive = False
    if not all_positive:
        return col, df
    df = df.copy()
    df[col] = np.log10(values)
    return f"log10[ {col} ]", df
def plot_x_vs_y(
    df,
    x: str,
    y: str,
    color: Optional[str] = None,
):
    """Build a scatter plot of ``y`` against ``x``, optionally colored.

    Each plotted column is log10-transformed when all its values are
    positive (see ``log10_if_all_positive``); axis titles reflect this. The
    id/inchikey/smiles/mwt/clogp columns are included in the plotted data
    alongside the plotted ones. ``df`` itself is not modified.

    Returns
    -------
    gradio.ScatterPlot
    """
    message = f"Plotting x={x}, y={y}, color={color}..."
    gr.Info(message, duration=10)
    print_err(df.head())
    # A cleared dropdown may send "" instead of None; both mean "no color".
    if color == "":
        color = None
    cols = ["id", "inchikey", "smiles", "mwt", "clogp", x, y]
    if color is not None:
        cols.append(color)
    # Deduplicate while keeping a stable order (set() order varies).
    cols = list(dict.fromkeys(cols))
    # Transform each distinct column at most once, so that x == y (or a
    # color equal to an axis) is not log-transformed twice.
    df = df.copy()
    titles = {}
    for col in dict.fromkeys(c for c in (x, y, color) if c is not None):
        title, df = log10_if_all_positive(df, col)
        titles[col] = title
    return gr.ScatterPlot(
        value=df[cols],
        x=x,
        y=y,
        color=color,
        x_title=titles[x],
        y_title=titles[y],
        color_title=titles.get(color),
        tooltip="all",
        visible=True,
    )
def plot_pred_vs_observed(
    df,
    species: str,
    observed: str,
    color: Optional[str] = None,
):
    """Plot an observed column against a species' predicted MIC (µM).

    Thin wrapper around ``plot_x_vs_y``: the x-axis is the column
    ``"{species}: predicted MIC (µM)"`` and the y-axis is ``observed``.
    """
    print_err(df.head())
    return plot_x_vs_y(
        df,
        x=f"{species}: predicted MIC (µM)",
        y=observed,
        color=color,
    )
def download_table(
    df: pd.DataFrame
):
    """Write ``df`` to a content-addressed CSV and return a download button.

    The file is written to ``{CACHE}/downloads/predicted-{hash}.csv``, where
    the hash is computed from the table's row hashes.

    Returns
    -------
    gradio.DownloadButton
        Visible button pointing at the written file. (The previous ``str``
        annotation did not match this return value.)
    """
    df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
    filename = os.path.join(CACHE, "downloads", f"predicted-{df_hash}.csv")
    # exist_ok=True avoids a race between the existence check and creation
    # if two requests write downloads at the same time.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    df.to_csv(filename, index=False)
    return gr.DownloadButton(value=filename, visible=True)
def _predict_then_draw_then_download(
    strings: str,
    input_representation: str = 'smiles',
    predict: Union[Iterable[str], str] = 'smiles',
    extra_metrics: Optional[Union[Iterable[str], str]] = None,
    smiles_col: str = "smiles",
    legends: Optional[Union[str, Iterable[str]]] = None
):
    """Event handler for pasted input: predict, draw, and offer a download.

    Returns
    -------
    tuple
        ``(gr.DataFrame, image, gr.DownloadButton)``.
    """
    df, gr_df = predict_one(
        strings=strings,
        input_representation=input_representation,
        predict=predict,
        extra_metrics=extra_metrics,
        return_pd=True,
    )
    # Forward the drawing options instead of ignoring them (smiles_col was
    # previously hard-coded and legends was never passed on).
    img = draw_one(
        df,
        smiles_col=smiles_col,
        legends=legends,
    )
    return gr_df, img, download_table(df)
def _load_then_predict_then_download_then_reveal_plot(
    file: str,
    column: str = 'smiles',
    input_representation: str = 'smiles',
    predict: str = 'smiles',
    predict2: Optional[str] = "",
    extra_metrics: Optional[Union[Iterable[str], str]] = None
):
    """Event handler for file input: load, predict, and offer a download.

    An empty-string ``predict2`` means "no second species".

    Returns
    -------
    tuple
        ``(gr.Dataframe, gr.DownloadButton, *plot_dropdowns)``: seven items,
        with the five plot-dropdown updates coming from ``predict_file``.
    """
    (input_df, _), _ = load_input_data(
        file,
        return_pd=True,
    )
    second_species = None if predict2 == "" else predict2
    (result_df, result_table), plot_updates = predict_file(
        input_df,
        column=column,
        input_representation=input_representation,
        predict=predict,
        predict2=second_species,
        extra_metrics=extra_metrics,
        return_pd=True,
    )
    print_err(result_df.head())
    return (result_table, download_table(result_df)) + plot_updates
def _initial_setup():
    """Create (without rendering) every component used by the demo.

    Called before the ``gr.Blocks`` context is opened. The layout code then
    places each component with ``.render()``.

    Returns
    -------
    tuple
        ``(line_inputs, output_line, download_single, drawing, file_inputs,
        input_dataframe, download, plot_button, left_plot_inputs,
        right_plot_inputs, plots)``.
    """
    print_err(f"Duvida config is {config}")
    print_err(f"Default torch device is {DEVICE}")

    input_formats = list(_FROM_FUNCTIONS)
    species_names = list(MODEL_REPOS)
    metric_names = list(EXTRA_METRICS)

    def _format_dropdown(**extra):
        # Selector for the input string representation.
        return gr.Dropdown(
            label="Input string format",
            choices=input_formats,
            value="smiles",
            interactive=True,
            **extra,
        )

    def _species_dropdown(label, value):
        return gr.Dropdown(
            label=label,
            choices=species_names,
            value=value,
            interactive=True,
            allow_custom_value=True,
        )

    def _blank_dropdown(label):
        # Empty, editable dropdown; choices are supplied later as an update.
        return gr.Dropdown(
            label=label,
            choices=[],
            value=None,
            interactive=True,
            visible=True,
            allow_custom_value=True,
        )

    def _hidden_scatter():
        return gr.ScatterPlot(
            height=500,
            visible=False,
        )

    # --- "Paste one per line" tab ---
    line_inputs = dict(
        format=_format_dropdown(),
        species=gr.CheckboxGroup(
            label="Species for prediction",
            choices=species_names,
            value=species_names[:1],
            interactive=True,
        ),
        extras=gr.CheckboxGroup(
            label="Extra metrics (Doubtscore & Information Sensitivity can increase calculation time to a couple of minutes!)",
            choices=metric_names,
            value=metric_names[:2],
            interactive=True,
        ),
        strings=gr.Textbox(
            label="Input",
            placeholder="Paste your molecule here, one per line.",
            lines=2,
            interactive=True,
            submit_btn=True,
        ),
    )
    output_line = gr.DataFrame(
        label="Predictions (scroll left and right)",
        interactive=False,
        visible=True,
    )
    download_single = gr.DownloadButton(
        label="Download predictions",
        visible=True,
    )
    drawing = gr.Image(label="Chemical structures")

    # --- "Predict on structures from a file" tab ---
    file_inputs = dict(
        file=gr.File(
            label="Upload a table of chemical compounds here",
            file_types=[".xlsx", ".csv", ".tsv", ".txt"],
        ),
        column=gr.Dropdown(
            label="Input column name",
            choices=[],
            allow_custom_value=True,
            visible=True,
            interactive=True,
        ),
        format=_format_dropdown(visible=True),
        species=[
            _species_dropdown("Species 1 for prediction", species_names[0]),
            _species_dropdown("Species 2 for prediction", None),
        ],
        extras=gr.CheckboxGroup(
            label="Extra metrics (Information Sensitivity can increase calculation time)",
            choices=metric_names,
            value=metric_names[:2],
            interactive=True,
        ),
    )
    input_dataframe = gr.Dataframe(
        label="Input data",
        max_height=500,
        visible=True,
        interactive=False,
        show_fullscreen_button=True,
        show_search="filter",
        max_chars=45,
    )
    download = gr.DownloadButton(
        label="Download predictions",
        visible=False,
    )
    plot_button = gr.Button(
        value="Plot!",
        visible=False,
    )
    left_plot_inputs = dict(
        observed=_blank_dropdown("Observed column (y-axis) for left plot"),
        color=_blank_dropdown("Color for left plot"),
    )
    right_plot_inputs = dict(
        x=_blank_dropdown("x-axis for right plot"),
        y=_blank_dropdown("y-axis for right plot"),
        color=_blank_dropdown("Color for right plot"),
    )
    plots = dict(
        left=_hidden_scatter(),
        right=_hidden_scatter(),
    )

    return (
        line_inputs,
        output_line,
        download_single,
        drawing,
        file_inputs,
        input_dataframe,
        download,
        plot_button,
        left_plot_inputs,
        right_plot_inputs,
        plots,
    )
if __name__ == "__main__":
    # Build every component before opening the Blocks context; each one is
    # placed into the layout below with an explicit .render() call.
    (
        line_inputs,
        output_line,
        download_single,
        drawing,
        file_inputs,
        input_dataframe,
        download,
        plot_button,
        left_plot_inputs,
        right_plot_inputs,
        plots,
    ) = _initial_setup()
    with gr.Blocks(theme=THEME) as demo:
        # Page header is read from a Markdown file.
        with open(HEADER_FILE, 'r') as f:
            header_md = f.read()
        gr.Markdown(header_md)
        # Tab 1: paste molecules as text, one per line.
        with gr.Tab(label="Paste one per line"):
            # Each example row is (strings, input format, species, extra metrics),
            # matching the order of `inputs` below. Examples are computed
            # eagerly at startup (cache_examples=True, cache_mode="eager").
            examples = gr.Examples(
                examples=[
                    [
                        "\n".join(eg["strings"]),
                        "smiles",
                        eg["species"],
                        list(EXTRA_METRICS)[:3],
                    ]
                    for eg in EXAMPLES["line input examples"]
                ],
                example_labels=[
                    eg["label"] for eg in EXAMPLES["line input examples"]
                ],
                examples_per_page=100,
                inputs=[
                    line_inputs["strings"],
                    line_inputs["format"],
                    line_inputs["species"],
                    line_inputs["extras"],
                ],
                fn=_predict_then_draw_then_download,
                outputs=[
                    output_line,
                    drawing,
                    download_single,
                ],
                cache_examples=True,
                cache_mode="eager",
            )
            for val in line_inputs.values():
                val.render()
            # with gr.Row():
            output_line.render()
            download_single.render()
            drawing.render()
            # Submitting the textbox runs prediction, drawing and download.
            line_inputs["strings"].submit(
                fn=_predict_then_draw_then_download,
                inputs=[
                    line_inputs["strings"],
                    line_inputs["format"],
                    line_inputs["species"],
                    line_inputs["extras"],
                ],
                outputs=[
                    output_line,
                    drawing,
                    download_single,
                ],
            )
        # Tab 2: upload a table and predict for up to two species.
        with gr.Tab(f"Predict on structures from a file (max. {MAX_ROWS} rows, ≤ 2 species)"):
            # The five plot dropdowns, in the same order as the 5-tuple of
            # dropdown updates returned by predict_file.
            plot_dropdowns = list(itertools.chain(
                left_plot_inputs.values(),
                right_plot_inputs.values(),
            ))
            # Each example row is (file, column, input format, species 1,
            # species 2 = "" meaning none, extra metrics).
            file_examples = gr.Examples(
                examples=[
                    [
                        eg["file"],
                        eg["column"],
                        "smiles",
                        eg["species"],
                        "",
                        list(EXTRA_METRICS)[:3],
                    ] for eg in EXAMPLES["file examples"]
                ],
                example_labels=[
                    eg["label"] for eg in EXAMPLES["file examples"]
                ],
                fn=_load_then_predict_then_download_then_reveal_plot,
                inputs=[
                    file_inputs["file"],
                    file_inputs["column"],
                    file_inputs["format"],
                    *file_inputs["species"],
                    file_inputs["extras"],
                ],
                outputs=[
                    input_dataframe,
                    download,
                    *plot_dropdowns,
                ],
                cache_examples=True, # NOTE(review): original author noted this "appears to cause CSV load error" — confirm
                cache_mode="eager",
            )
            file_inputs["file"].render()
            with gr.Row():
                for key in ("column", "format"):
                    file_inputs[key].render()
            with gr.Row():
                for item in file_inputs["species"]:
                    item.render()
            file_inputs["extras"].render()
            go_button2 = gr.Button(value="Predict!")
            input_dataframe.render()
            download.render()
            with gr.Row():
                for val in left_plot_inputs.values():
                    val.render()
            with gr.Row():
                for val in right_plot_inputs.values():
                    val.render()
            plot_button.render()
            with gr.Row():
                for val in plots.values():
                    val.render()
            # On upload, preview the table and populate the column dropdown.
            file_inputs["file"].upload(
                fn=load_input_data,
                inputs=file_inputs["file"],
                outputs=[
                    input_dataframe,
                    file_inputs["column"],
                ],
            )
            # Predict on the file, then reveal the (initially hidden) plot
            # button. NOTE(review): js=True presumably asks Gradio to run the
            # lambda client-side — confirm against the installed Gradio version.
            go2_click_event = go_button2.click(
                _load_then_predict_then_download_then_reveal_plot,
                inputs=[
                    file_inputs["file"],
                    file_inputs["column"],
                    file_inputs["format"],
                    *file_inputs["species"],
                    file_inputs["extras"],
                ],
                outputs=[
                    input_dataframe,
                    download,
                    *plot_dropdowns,
                ],
                scroll_to_output=True,
            ).then(
                lambda: gr.Button(visible=True),
                outputs=[plot_button],
                js=True,
            )
            # Also reveal the plot button after an example has been loaded.
            file_examples.load_input_event.then(
                lambda: gr.Button(visible=True),
                outputs=[plot_button],
                js=True,
            )
            # Left plot: observed vs. species-1 predicted MIC; then the right
            # plot: free choice of x, y and color.
            plot_button.click(
                plot_pred_vs_observed,
                inputs=[
                    input_dataframe,
                    file_inputs["species"][0],
                    left_plot_inputs["observed"],
                    left_plot_inputs["color"],
                ],
                outputs=[plots["left"]],
                scroll_to_output=True,
            ).then(
                plot_x_vs_y,
                inputs=[
                    input_dataframe,
                    right_plot_inputs["x"],
                    right_plot_inputs["y"],
                    right_plot_inputs["color"],
                ],
                outputs=[plots["right"]],
            )
    demo.queue()
    demo.launch(share=True)