from collections import deque
from pathlib import Path
import pickle
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import black
import cv2
import inspect
import numpy as np
import pandas as pd
import io
from jinja2 import Template
from PIL import Image
import sklearn
# The submodules below are accessed as attributes of `sklearn`; `import sklearn`
# alone does not guarantee they are loaded, so import them explicitly.
import sklearn.decomposition
import sklearn.preprocessing
import sklearn.random_projection
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.base import ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report
import traceback
import yaml
from util import *

import logging

logging.basicConfig(
    level=logging.INFO,  # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # log format
)
logger = logging.getLogger("ELVIS")

# TODO:
# - support for session: load a previous session and continue from there


def _sample_cmap(name, n):
    """Return ``n`` hex color strings sampled evenly over the colormap ``name``.

    Calling the colormap on ``linspace(0, 1, n)`` works for both listed and
    continuous (linear-segmented) colormaps; reading ``.colors`` only works
    for listed ones.
    """
    colormap = plt.get_cmap(name)
    return [mcolors.to_hex(rgba) for rgba in colormap(np.linspace(0, 1, n))]


def label2color(labels, cmap=None):
    '''
    Assign one matplotlib color to each label.

    Parameters
    ----------
    labels: a list of distinct strings
    cmap: optional name of a matplotlib colormap. When non-empty,
        ``len(labels)`` colors are sampled evenly from it; otherwise a named
        color palette large enough for the number of labels is used.

    Returns
    -------
    A list of color specifications, one per label, in the order of ``labels``.
    A label whose lower-cased name is itself an entry of the chosen palette
    (e.g. "Blue" -> "blue") keeps that color.
    '''
    n_labels = len(labels)
    if (cmap is not None) and (cmap != ''):
        # sample the number of needed colors from the user-specified color map
        palette = _sample_cmap(cmap, n_labels)
    elif n_labels <= len(mcolors.TABLEAU_COLORS):  # 10 colors
        # strip "tab:" so label names such as "blue" can match palette entries
        palette = [c.replace('tab:', '') for c in mcolors.TABLEAU_COLORS]
    elif n_labels <= len(mcolors.CSS4_COLORS):  # 148 colors
        palette = list(mcolors.CSS4_COLORS)
    elif n_labels <= len(mcolors.XKCD_COLORS):  # 949 colors
        # keep the "xkcd:" prefix: matplotlib only resolves XKCD names with it
        palette = list(mcolors.XKCD_COLORS)
    else:  # very unlikely
        palette = _sample_cmap('viridis', n_labels)

    color_deque = deque(palette)
    colors = []
    for label in labels:
        name = str(label).lower()
        if name in color_deque:
            # the label names a palette color: use it and don't hand it out again
            colors.append(name)
            color_deque.remove(name)
        else:
            colors.append(color_deque.popleft())
    return colors


def toydata():
    """Return a small hand-made 2-class dataset with columns ``label``, ``F1``, ``F2``."""
    points = [
        ['Red', 0.12375, 0.8516666666666667],
        ['Red', 0.19, 0.8916666666666666],
        ['Red', 0.27375, 0.9233333333333333],
        ['Blue', 0.50625, 0.785],
        ['Blue', 0.38375, 0.6733333333333333],
        ['Blue', 0.28875, 0.595],
    ]
    df = pd.DataFrame(points, columns=['label', 'F1', 'F2'])
    return df


class CoordinateProjection2d:
    """
    Project data on the two coordinates.

    ``transform`` keeps only the feature columns ``dim0`` and ``dim1``;
    ``inverse_transform`` fills every other feature with its mean from ``fit``.
    Extra keyword arguments (e.g. ``n_components``) are accepted and ignored so
    the class can be constructed like the sklearn embedders.
    """

    def __init__(self, dim0=0, dim1=1, **kwargs):
        self.dims = [dim0, dim1]

    def fit(self, X):
        # per-feature mean, used to fill the dropped features in inverse_transform
        self.mean = X.mean(axis=0)
        return self

    def transform(self, X):
        self._check_dims(X)
        return X[:, self.dims]

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, Z):
        # start every row at the feature means, then overwrite the two projected columns
        X = np.ones((len(Z), 1)) * self.mean
        self._check_dims(X)
        X[:, self.dims] = Z
        return X

    def _check_dims(self, X):
        """Raise ValueError if either projected dimension is out of range for ``X``."""
        n_features = X.shape[1]
        if self.dims[0] >= n_features:
            raise ValueError(f"dim0={self.dims[0]} exceeds the number of features {n_features}")
        if self.dims[1] >= n_features:
            raise ValueError(f"dim1={self.dims[1]} exceeds the number of features {n_features}")


class InteractiveDecisionBoundary:
    """
    Gradio app that plots a labelled dataset (drawn by clicking, uploaded as a
    file, or loaded from sklearn), trains a classifier on it and shades the
    classifier's predictions over a 2D embedding of the features.
    """

    DATASET_FILE = "dataset.csv"
    MODEL_FILE = "model.pkl"
    FIGURE_BASENAME = "figure"
    CODE_FILE = "generated_code.py"
    EXPORT_CODE_TEMPLATE = "export_code_template.py.j2"

    def __init__(self, width, height):
        # canvas size in pixels at dpi=100 (the figure size is width/100 x height/100 inches)
        self.canvas_width = width
        self.canvas_height = height

        # classifiers: only those listed in model_imports.yaml (needed for code export)
        supported_classifier_names = yaml.safe_load(Path("model_imports.yaml").read_text())
        self.classifiers = {
            k: v for k, v in get_sklearn_classifiers().items() if k in supported_classifier_names
        }
        self.model_class = LinearSVC
        self.model_args = ""
        self.model = self.model_class()
        self.dataloaders = get_sklearn_dataloaders()

        # embedders used to project features to 2D for plotting
        supported_embedder_names = yaml.safe_load(Path("embedder_imports.yaml").read_text())
        self.embedders = {
            'CoordinateProjection2d': CoordinateProjection2d,
            'GaussianRandomProjection': sklearn.random_projection.GaussianRandomProjection,
            'SparseRandomProjection': sklearn.random_projection.SparseRandomProjection,
        }
        for cls_name, cls in inspect.getmembers(sklearn.decomposition, inspect.isclass):
            if cls_name in supported_embedder_names:
                self.embedders[cls_name] = cls

        # normalizers
        self.normalizers = {
            'None': None,
            'MinMaxScaler': sklearn.preprocessing.MinMaxScaler,
            'StandardScaler': sklearn.preprocessing.StandardScaler,
        }

        # data embedding and preprocessing values
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""
        self.normalizer_class = None
        self.jitter_std = 0

        # todo: support arbitrary number of classes and user-defined class labels
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        self.dataset_type = 'Draw2D'
        self.custom_selected = True

        # options
        self.num_dots = 200
        self.dpi = 100
        self.cmap = None
        self.precision = 2  # number of decimal places to show in datatable
        self.marker_size = 100

        self.data_image = None
        self.boundary_image = None
        # pixel position/size of the black [0, 1] x [0, 1] frame in the rendered
        # Draw2D image; used by add_point to map clicks to data coordinates
        self._axis_topleft = (0, 0)
        self._axis_size = (width, height)
        # (features, targets) exactly as used for the last model fit, or None
        self._fit_data = None
        self.figure_extension = ".svg"
        self.css = """
#my-button { height: 30px; font-size: 16px; }
#rowheight { height: 90px; }
.file-chooser { height: 150px; }
.hidden-button { display: none; }
.report-table { border: 0 !important; }
.report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
    border: 0 !important;
    padding: 6px 12px;
    text-align: center;
}"""

    def _get_features(self):
        """Return the raw feature values (every column except ``target``) as an array.

        Normalization and jittering are applied separately by ``_process_features``.

        Raises
        ------
        ValueError
            If the dataset has no rows.
        """
        X = self.dataset.loc[:, self.dataset.columns != 'target'].values
        if len(X) == 0:
            raise ValueError("The dataset is empty or not properly formatted.")
        return X

    def _process_features(self, features):
        """Apply the selected normalizer (fit on ``features``) and Gaussian jitter, if enabled."""
        if self.normalizer_class is not None:
            normalizer = self.normalizer_class()
            features = normalizer.fit_transform(features)
        if self.jitter_std > 0:
            noise = np.random.normal(0, self.jitter_std, features.shape)
            # not in-place: `+=` fails for integer arrays and could write into
            # an array that shares memory with the dataset
            features = features + noise
        return features

    def _embed_features(self, features, return_embedder=False):
        """Fit the selected embedder and return the 2D embedding (and the fitted embedder if requested)."""
        embedder = self.embedder_class(n_components=2, **parse_param_string(self.embedder_args))
        features = embedder.fit_transform(features)
        if return_embedder:
            return features, embedder
        return features

    def _reset_data_processing_and_embedding(self):
        # Reset the values
        self.normalizer_class = None
        self.jitter_std = 0
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""

    def _data_table(self, max_rows=None):
        """Return a visible ``gr.Dataframe`` of the dataset rounded to ``self.precision``.

        If ``max_rows`` is given, only the first ``max_rows`` rows are shown.
        """
        data = self.dataset if max_rows is None else self.dataset[:max_rows]
        return gr.Dataframe(
            value=data.round(self.precision),
            visible=True,
            headers=list(self.dataset.columns),
        )

    def plot(self, decision_boundary=False, save_figure=False):
        '''
        Plot the data (and optionally the decision boundary) with matplotlib and
        return it as a PIL image.

        Parameters
        ----------
        decision_boundary: if True, fit a fresh ``self.model_class`` on the
            processed features, store it in ``self.model`` and shade its
            predictions on a ``num_dots`` x ``num_dots`` grid.
        save_figure: if True, also write the figure to
            ``FIGURE_BASENAME + figure_extension``.

        Errors raised while processing, embedding or fitting are logged and
        reported with ``gr.Info``; the partially drawn figure is still returned.
        In Draw2D mode the pixel position and size of the black unit frame are
        recorded for ``add_point``.
        '''
        logger.info("Initializing figure")
        if decision_boundary:
            # forget the previous fit so train() never reports on stale data
            self._fit_data = None

        fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.), dpi=self.dpi)
        try:
            # the axes cover the whole figure so that mouse positions can be
            # converted to coordinates in the figure
            ax = fig.add_axes([0., 0., 1., 1.])

            if self.dataset_type == 'Draw2D':
                # draw canvas boundary
                # DO NOT CHANGE THE COLOR OF THE BOUNDARY
                # IT WILL BREAK THE ORIGIN COORDINATE DETECTION
                ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='black')
            # dark gray (not black) spines so they are not picked up by the
            # black-frame detection in _locate_axis_box
            for spine in ax.spines.values():
                spine.set_color((0.1, 0.1, 0.1))
            # TODO: allow showing x and y axes with ticks and labels

            if self.dataset is not None and len(self.dataset) > 0:
                try:
                    self._draw_data(ax, decision_boundary)
                except Exception as e:
                    logger.exception("Plotting failed")
                    gr.Info(f"⚠️ {e}")

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
            if save_figure:
                fig.savefig(f"{self.FIGURE_BASENAME}{self.figure_extension}")
        finally:
            # always release the figure, even if saving fails
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        # detect axis pixel positions
        if self.dataset_type == 'Draw2D':
            self._axis_topleft, self._axis_size = self._locate_axis_box(img)
        else:
            self._axis_topleft = 0, 0
            self._axis_size = (self.canvas_width, self.canvas_height)

        # TODO: add a save function for saving screenshot
        return img

    def _draw_data(self, ax, decision_boundary):
        """Scatter the processed, embedded dataset on ``ax``; optionally fit the model and shade its predictions."""
        X = self._get_features()
        y = self.dataset.target.values
        logger.info("Data:\n" + str(X))
        logger.info("Target:\n" + str(y))

        # preprocess features
        X = self._process_features(X)
        # embed features to 2D for visualization
        Z, embedder = self._embed_features(X, return_embedder=True)

        labels = np.unique(y)
        colors = label2color(labels, cmap=self.cmap)
        logger.info("Classes:\n" + str(labels))
        logger.info("Colors:\n" + str(colors))
        label_to_color = dict(zip(labels, colors))

        # scatter plots for data
        for label, color in zip(labels, colors):
            subset = Z[y == label]
            ax.scatter(subset[:, 0], subset[:, 1], color=color, label=label, s=self.marker_size)
        ax.legend()

        if decision_boundary:
            self._draw_decision_boundary(ax, X, y, Z, embedder, label_to_color)

    def _draw_decision_boundary(self, ax, X, y, Z, embedder, label_to_color):
        """Fit a new model on ``(X, y)`` and color a grid in embedding space by its predictions."""
        model = self.model_class(**parse_param_string(self.model_args))
        model.fit(X, y)
        self.model = model
        # keep the exact training data so train() reports on the same features
        self._fit_data = (X, y)

        # grid in the projected space; normalization/jittering must not be
        # applied to the grid points, they are mapped back with the embedder
        if self.dataset_type == 'Draw2D':
            xs = np.linspace(0, 1, self.num_dots)
            ys = np.linspace(0, 1, self.num_dots)
        else:
            xs = np.linspace(Z[:, 0].min(), Z[:, 0].max(), self.num_dots)
            ys = np.linspace(Z[:, 1].min(), Z[:, 1].max(), self.num_dots)
        xx, yy = np.meshgrid(xs, ys)
        grid = np.c_[xx.ravel(), yy.ravel()]

        grid_features = embedder.inverse_transform(grid)
        preds = model.predict(grid_features).reshape(xx.shape)
        ax.scatter(
            xx.ravel(),
            yy.ravel(),
            c=[label_to_color[p] for p in preds.ravel()],
            s=1,
            alpha=0.5,
        )

    def _locate_axis_box(self, img):
        """Return ``((x, y), (w, h))``: the pixel bounding box of the largest black contour in ``img``.

        In Draw2D mode this is the black [0, 1] x [0, 1] frame drawn by ``plot``.
        Falls back to ``((0, 0), (canvas_width, canvas_height))`` if nothing is found.
        """
        gray = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2GRAY)
        black_mask = (gray < 0.05 * 255).astype(np.uint8)
        contours, _ = cv2.findContours(black_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # find the contour with the largest bounding-box area
        max_area = 0
        topleft = (0, 0)
        size = (self.canvas_width, self.canvas_height)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w * h > max_area:
                max_area = w * h
                topleft = (x, y)
                size = (w, h)
        return topleft, size

    def update_resolution(self, num_dots):
        self.num_dots = num_dots
        return self.plot(decision_boundary=True)

    def update_dpi(self, dpi):
        self.dpi = dpi
        return self.plot(decision_boundary=True)

    def update_cmap(self, cmap):
        self.cmap = cmap
        return self.plot(decision_boundary=True)

    def update_precision(self, precision):
        self.precision = precision
        return self._data_table()

    def update_marker_size(self, size):
        self.marker_size = size
        return self.plot(decision_boundary=True)

    def add_point(self, evt: gr.SelectData, label):
        '''
        Mouse click to add a point.

        In Draw2D mode the clicked pixel is mapped into the [0, 1] x [0, 1]
        frame found by the last ``plot`` call; clicks outside the frame are
        ignored. Returns the re-rendered image and the data table.
        '''
        if self.custom_selected:
            if self.dataset_type != 'Draw2D':
                self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
                self.dataset_type = 'Draw2D'
            shift_x, shift_y = self._axis_topleft
            box_width, box_height = self._axis_size
            # normalize clicked position to [0, 1] relative to the detected
            # frame (which is smaller than the canvas because of axes margins)
            x = (evt.index[0] - shift_x) / box_width
            y = 1 - (evt.index[1] - shift_y) / box_height  # flip y-axis to match matplotlib
            if 0 <= x <= 1 and 0 <= y <= 1:
                self.dataset.loc[len(self.dataset)] = [label, x, y]
            logger.info(f'clicked ({evt.index[0]}, {evt.index[1]}), mapped to ({x}, {y})')
        vis = self.plot()
        return vis, self._data_table()

    # train a model and show decision boundary
    def train(self):
        '''
        Fit the model (via ``plot``) and build an HTML classification report.

        Returns the image, an HTML visibility update and the report HTML (empty
        on failure; the error is shown with ``gr.Info``).
        '''
        image = self.plot(decision_boundary=True)
        try:
            if self._fit_data is None:
                raise ValueError("No model was trained; check the dataset and the classifier arguments.")
            # evaluate on exactly the (normalized/jittered) features the model was fit on
            X, y = self._fit_data
            pred = self.model.predict(X)
            df = pd.DataFrame(classification_report(y, pred, output_dict=True)).T
            summary = df.to_html(classes="report-table", float_format="%.2f")
            return image, gr.HTML(visible=True), "<h3>Classification report</h3>" + summary
        except Exception as e:
            logger.exception("Training report failed")
            gr.Info(f"⚠️ {e}")
            return image, gr.HTML(visible=False), ""

    # clear data points and data preprocessing
    def clear(self):
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        return self.plot(), gr.Dataframe(visible=False)

    def save(self):
        self.save_data()

    # save dataset
    def save_data(self):
        # TODO: allow user-specified filename
        self.dataset.to_csv(self.DATASET_FILE, index=False)
        logger.info(f"{self.DATASET_FILE} updated")
        return self.DATASET_FILE

    def update_model(self, classifier_name):
        '''Select a classifier class by name and clear its arguments (returns "" for the args textbox).'''
        if classifier_name not in self.classifiers:
            raise gr.Error(f"Unknown classifier: {classifier_name}")
        self.model_class = self.classifiers[classifier_name]
        # the returned "" clears the args textbox; keep the stored args in sync
        # so arguments of the previous classifier are not passed to the new one
        self.model_args = ""
        logger.info(f'Updated model to {self.model_class}')
        return ""

    def save_model(self):
        with open(self.MODEL_FILE, "wb") as f:
            pickle.dump(self.model, f)
        logger.info(f"{self.MODEL_FILE} updated")
        return self.MODEL_FILE

    def save_code(self):
        '''
        Render a standalone script reproducing the current pipeline and plot.

        Returns the path of the written file.

        Raises
        ------
        ValueError
            If the model, embedder or normalizer class has no entry in the
            corresponding ``*_imports.yaml`` file.
        '''
        model_class = str(self.model_class.__name__)
        model_imports = yaml.safe_load(Path("model_imports.yaml").read_text())
        if model_class not in model_imports:
            raise ValueError(f"Model {model_class} not found in model_imports.yaml")
        model_import_stmt = f"{model_imports[model_class]}"

        embedder_class = str(self.embedder_class.__name__)
        if embedder_class == "CoordinateProjection2d":
            # not importable from sklearn: inline the class source instead
            embedder_import_stmt = f"\n\n\n{inspect.getsource(CoordinateProjection2d)}".rstrip()
        else:
            embedder_imports = yaml.safe_load(Path("embedder_imports.yaml").read_text())
            if embedder_class not in embedder_imports:
                raise ValueError(f"Embedder {embedder_class} not found in embedder_imports.yaml")
            embedder_import_stmt = f"\n{embedder_imports[embedder_class]}"

        if self.normalizer_class is not None:
            normalizer_class = str(self.normalizer_class.__name__)
            normalizer_imports = yaml.safe_load(Path("normalizer_imports.yaml").read_text())
            if normalizer_class not in normalizer_imports:
                raise ValueError(f"Normalizer {normalizer_class} not found in normalizer_imports.yaml")
            normalizer_import_stmt = f"\n{normalizer_imports[normalizer_class]}"
        else:
            normalizer_import_stmt = ""

        # grid limits: literal numbers for Draw2D, expressions evaluated in the generated script otherwise
        if self.dataset_type == 'Draw2D':
            x_min = 0
            x_max = 1
            y_min = 0
            y_max = 1
        else:
            x_min = "X_embedded[:, 0].min()"
            x_max = "X_embedded[:, 0].max()"
            y_min = "X_embedded[:, 1].min()"
            y_max = "X_embedded[:, 1].max()"

        model_params = parse_param_string(self.model_args)
        if len(model_params) == 0:
            model_params_text = ""
        else:
            model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"

        if embedder_class == "CoordinateProjection2d":
            embedder_args = {**parse_param_string(self.embedder_args)}
        else:
            embedder_args = {"n_components": 2, **parse_param_string(self.embedder_args)}
        embedder_args_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in embedder_args.items()]) + "\n\t"

        template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
        variables = {
            'model_import_statement': model_import_stmt,
            'embedder_import_statement': embedder_import_stmt,
            "normalizer_import_statement": normalizer_import_stmt,
            'dataset_file': self.DATASET_FILE,
            'embedder_class': embedder_class,
            'embedder_args': embedder_args_text,
            'model_class': model_class,
            'model_params': model_params_text,
            'fig_width': self.canvas_width / 100,
            'fig_height': self.canvas_height / 100,
            'dpi': 100,
            'num_dots': self.num_dots,
            'x_min': x_min,
            'x_max': x_max,
            'y_min': y_min,
            'y_max': y_max,
            "normalize": self.normalizer_class is not None,
            "normalizer_class": self.normalizer_class.__name__ if self.normalizer_class is not None else "",
            "jitter": self.jitter_std > 0,
            "jitter_scale": self.jitter_std,
        }
        rendered_code = template.render(variables)
        rendered_code = black.format_str(rendered_code, mode=black.FileMode())
        Path(self.CODE_FILE).write_text(rendered_code)
        logger.info(f"{self.CODE_FILE} updated")
        return self.CODE_FILE

    def update_figure_extension(self, ext):
        self.figure_extension = ext
        logger.info(f'updated figure extension: {self.figure_extension}')

    def save_figure(self):
        self.plot(decision_boundary=True, save_figure=True)
        return f"{self.FIGURE_BASENAME}{self.figure_extension}"

    def update_args(self, model_args):
        self.model_args = model_args
        logger.info(f'updated model_args: {self.model_args}')

    def update_embedder(self, embedder):
        if embedder not in self.embedders:
            raise gr.Error(f"Unknown embedder: {embedder}")
        self.embedder_class = self.embedders[embedder]
        logger.info(f'updated Embedder: {self.embedder_class}')
        return self.plot()

    def update_embedder_args(self, embedder_args):
        self.embedder_args = embedder_args
        logger.info(f'updated Embedder args: {self.embedder_args}')
        return self.plot()

    def update_normalizer(self, normalizer):
        self.normalizer_class = self.normalizers[normalizer]
        logger.info(f'updated Normalizer: {self.normalizer_class}')
        return self.plot(), self._data_table(max_rows=100)

    def update_jittering(self, jitter_std):
        try:
            self.jitter_std = float(jitter_std)
        except ValueError:
            self.jitter_std = 0
        logger.info(f'updated Jittering std: {self.jitter_std}')
        return self.plot()

    @staticmethod
    def _dataset_field_updates(mode):
        '''
        Visibility/value updates for the dataset-tab widgets in ``mode``.

        Order: file_chooser, sklearn_data_selector, normalizer_selector,
        jittering_textbox, embedder_selector, embedder_args_textbox, label,
        btn_clear, btn_save.
        '''
        draw = mode == 'Draw2D'
        upload = mode == 'Upload'
        from_sklearn = mode == 'sklearn'
        # preprocessing/embedding controls are only offered for loaded datasets
        preprocess = upload or from_sklearn
        return (
            gr.File(visible=upload),
            gr.Dropdown(visible=from_sklearn, value="None"),
            gr.Dropdown(visible=preprocess, value="None"),
            gr.Textbox(visible=preprocess, value="0"),
            gr.Dropdown(visible=preprocess, value="CoordinateProjection2d"),
            gr.Textbox(visible=preprocess),
            gr.Radio(visible=draw),
            gr.Button(visible=draw),
            gr.Button(visible=False),
        )

    def handle_dataset_radio(self, type):
        '''Switch dataset mode, reset preprocessing, clear the data and update widget visibility.'''
        if type in ('Draw2D', 'Upload', 'sklearn'):
            self.dataset_type = type
            self.custom_selected = type == 'Draw2D'
            new_fields = self._dataset_field_updates(type)
        else:
            # TODO: better error handling
            logger.error(f'Error - unknown dataset type: {type}')
            # leave the widgets untouched instead of failing on an unbound name
            new_fields = tuple(gr.update() for _ in range(9))
        self._reset_data_processing_and_embedding()
        plot, data_table = self.clear()
        return plot, data_table, *new_fields

    def load_local_data_and_plot(self, filename):
        '''Load a dataset file (if given) and return the plot and the first 100 rows.'''
        if filename is not None:
            self.dataset = read(filename)
            self.dataset.target = self.dataset.target.astype(str)
            self.dataset_type = 'Upload'
            logger.info(f'Loaded dataset from {filename}')
        vis = self.plot()
        # TODO: need to make it explicit that this only shows first 100 points
        return vis, self._data_table(max_rows=100)

    def load_sklearn_data_and_plot(self, datasetname):
        '''Load a named sklearn dataset (unless "None") and return the plot and the first 100 rows.'''
        if datasetname is not None and datasetname != "None":
            dataset = self.dataloaders[datasetname]()
            X = dataset.data
            y = dataset.target
            if hasattr(dataset, 'feature_names'):
                feature_names = dataset.feature_names
            else:
                feature_names = [f'F{i}' for i in range(len(X[0]))]
            if hasattr(dataset, 'target_names'):
                labels = dataset.target_names
            else:
                labels = [f'C{i}' for i in range(len(np.unique(y)))]
            # replace integer class ids by their names
            y = np.array([labels[i] for i in y])
            self.dataset = pd.DataFrame(X, columns=feature_names)
            self.dataset['target'] = y.astype(str)
            self.dataset_type = 'sklearn'
            logger.info(f'Loaded dataset {datasetname}')
        vis = self.plot()
        # TODO: need to make it explicit that this only shows first 100 points
        return vis, self._data_table(max_rows=100)

    def launch(self):
        '''Build the Gradio interface, wire up its event handlers and start the server.'''
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<h2 style='text-align: center;'>Interactive Decision Boundary Visualizer</h2>")

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    self.data_image = gr.Image(
                        value=self.plot(),
                        container=True,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        show_download_button=False,
                        show_label=False,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_radio = gr.Radio(
                            ["Draw2D", "Upload", "sklearn"],
                            value="Draw2D",
                            label="Dataset type",
                            elem_id="rowheight",
                        )

                        # upload data
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_classes="file-chooser")
                        self.file_chooser = file_chooser

                        # sklearn data dropdown menu
                        sklearn_data_selector = gr.Dropdown(
                            choices=self.dataloaders,
                            label='Select dataset',
                            value='None',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.sklearn_data_selector = sklearn_data_selector

                        # normalization
                        normalizer_selector = gr.Dropdown(
                            choices=self.normalizers,
                            label='Select normalizer',
                            value='None',
                            visible=False,
                        )
                        self.normalizer_selector = normalizer_selector

                        # jittering
                        jittering_textbox = gr.Textbox(label="Set jittering std", value="0", visible=False)
                        self.jittering_textbox = jittering_textbox

                        # embedder
                        embedder_selector = gr.Dropdown(
                            choices=self.embedders,
                            label='Select embedder (only for plotting)',
                            value='CoordinateProjection2d',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.embedder_selector = embedder_selector
                        embedder_args_textbox = gr.Textbox(label="Embedder arguments", visible=False)
                        self.embedder_args_textbox = embedder_args_textbox

                        # custom data
                        label = gr.Radio(
                            ["Gray", "Orange", "Blue"],
                            value="Gray",
                            label="Choose point label",
                            visible=True,
                            elem_id="rowheight",
                        )
                        self.label = label
                        with gr.Row():
                            btn_clear = gr.Button("Clear", visible=True, elem_id="my-button")
                            self.btn_clear = btn_clear
                            btn_save = gr.Button("Save", visible=False, elem_id="my-button")
                            self.btn_save = btn_save

                        data_table = gr.Dataframe(visible=False)

                    # classifier selector
                    with gr.Tab("Classifier"):
                        # specify model
                        model_selector = gr.Dropdown(
                            choices=self.classifiers,
                            label='Select Classifier',
                            value='LinearSVC',
                            allow_custom_value=True,
                        )
                        self.model_selector = model_selector

                        # specify arguments
                        args_textbox = gr.Textbox(label="Classifier arguments")
                        self.args_textbox = args_textbox
                        model_selector.change(fn=self.update_model, inputs=model_selector, outputs=args_textbox)

                        btn_train = gr.Button("Train Model")
                        classification_summary = gr.HTML(visible=False)

                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        btn_export_data = gr.Button("Data")
                        btn_export_data_hidden = gr.DownloadButton(
                            label="You should not see this",
                            elem_id="btn_export_data_hidden",
                            elem_classes="hidden-button",
                        )
                        btn_export_model = gr.Button('Model')
                        btn_export_model_hidden = gr.DownloadButton(
                            label="You should not see this",
                            elem_id="btn_export_model_hidden",
                            elem_classes="hidden-button",
                        )
                        btn_export_code = gr.Button('Code')
                        btn_export_code_hidden = gr.DownloadButton(
                            label="You should not see this",
                            elem_id="btn_export_code_hidden",
                            elem_classes="hidden-button",
                        )
                        with gr.Row():
                            btn_export_figure = gr.Button('Figure')
                            btn_export_figure_hidden = gr.DownloadButton(
                                label="You should not see this",
                                elem_id="btn_export_figure_hidden",
                                elem_classes="hidden-button",
                            )
                            figure_extension_selector = gr.Dropdown(
                                choices=['.svg', '.pdf', '.png', '.jpeg'],
                                label="File extension",
                                value=".svg",
                            )

                    with gr.Tab("Options"):
                        grid_resolution_slider = gr.Slider(
                            minimum=100, maximum=1000, value=200, step=10,
                            label="Decision boundary grid resolution",
                        )
                        cmap_textbox = gr.Textbox(label="Colormap")
                        precision_slider = gr.Slider(
                            minimum=0, maximum=20, value=2, step=1,
                            label="# decimal place in datatable",
                        )
                        marker_size_slider = gr.Slider(
                            minimum=0, maximum=200, value=100, step=5,
                            label="Marker size",
                        )

                    with gr.Tab("Usage"):
                        gr.Markdown(Path('usage.md').read_text(encoding='utf-8'))

            # event handlers for GUI elements
            self.data_image.select(self.add_point, inputs=label, outputs=(self.data_image, data_table))
            dataset_radio.change(
                fn=self.handle_dataset_radio,
                inputs=dataset_radio,
                outputs=(
                    self.data_image,
                    data_table,
                    file_chooser,
                    sklearn_data_selector,
                    normalizer_selector,
                    jittering_textbox,
                    embedder_selector,
                    embedder_args_textbox,
                    label,
                    btn_clear,
                    btn_save,
                ),
            )

            # events for custom dataset
            btn_clear.click(fn=self.clear, outputs=(self.data_image, data_table))
            btn_save.click(fn=self.save)

            # events for local dataset
            file_chooser.change(
                fn=self.load_local_data_and_plot,
                inputs=file_chooser,
                outputs=(self.data_image, data_table),
            )

            # events for sklearn dataset
            sklearn_data_selector.change(
                fn=self.load_sklearn_data_and_plot,
                inputs=sklearn_data_selector,
                outputs=(self.data_image, data_table),
            )
            embedder_selector.change(fn=self.update_embedder, inputs=embedder_selector, outputs=self.data_image)
            embedder_args_textbox.change(
                fn=self.update_embedder_args,
                inputs=embedder_args_textbox,
                outputs=self.data_image,
            )
            normalizer_selector.change(
                fn=self.update_normalizer,
                inputs=normalizer_selector,
                outputs=(self.data_image, data_table),
            )
            jittering_textbox.change(
                fn=self.update_jittering,
                inputs=jittering_textbox,
                outputs=self.data_image,
            )

            # update_args must finish before train() reads self.model_args, so
            # chain them instead of registering two independent click listeners
            btn_train.click(fn=self.update_args, inputs=args_textbox).then(
                fn=self.train,
                outputs=(self.data_image, classification_summary, classification_summary),
            )

            # events for export
            # create files on the fly using hidden download buttons
            # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
            btn_export_data.click(
                fn=self.save_data, inputs=None, outputs=[btn_export_data_hidden]
            ).then(
                fn=None, inputs=None, outputs=None,
                js="() => document.querySelector('#btn_export_data_hidden').click()",
            )
            btn_export_model.click(
                fn=self.save_model, inputs=None, outputs=[btn_export_model_hidden]
            ).then(
                fn=None, inputs=None, outputs=None,
                js="() => document.querySelector('#btn_export_model_hidden').click()",
            )
            btn_export_code.click(
                fn=self.save_code, inputs=None, outputs=[btn_export_code_hidden]
            ).then(
                fn=None, inputs=None, outputs=None,
                js="() => document.querySelector('#btn_export_code_hidden').click()",
            )
            btn_export_figure.click(
                fn=self.save_figure, inputs=None, outputs=[btn_export_figure_hidden]
            ).then(
                fn=None, inputs=None, outputs=None,
                js="() => document.querySelector('#btn_export_figure_hidden').click()",
            )
            figure_extension_selector.change(self.update_figure_extension, inputs=figure_extension_selector)

            # events for options
            grid_resolution_slider.change(self.update_resolution, inputs=grid_resolution_slider, outputs=self.data_image)
            cmap_textbox.submit(self.update_cmap, inputs=cmap_textbox, outputs=self.data_image)
            precision_slider.change(self.update_precision, inputs=precision_slider, outputs=data_table)
            marker_size_slider.change(self.update_marker_size, inputs=marker_size_slider, outputs=self.data_image)

        demo.launch()


visualizer = InteractiveDecisionBoundary(width=1200, height=900)
visualizer.launch()