Spaces:
Running
Running
| from collections import deque | |
| from pathlib import Path | |
| import pickle | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| import black | |
| import cv2 | |
| import inspect | |
| import numpy as np | |
| import pandas as pd | |
| import io | |
| from jinja2 import Template | |
| from PIL import Image | |
| import sklearn | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.svm import LinearSVC | |
| from sklearn.base import ClassifierMixin | |
| from sklearn.datasets import load_iris | |
| from sklearn.decomposition import PCA | |
| from sklearn.metrics import classification_report | |
| import traceback | |
| import yaml | |
| from util import * | |
| import logging | |
# Configure root logging once for the whole app: messages of level INFO and
# above, rendered as "<timestamp> [<LEVEL>] <message>".
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Named application logger shared by every component of this module.
logger = logging.getLogger("ELVIS")

# TODO:
# - support for session: load a previous session and continue from there
def label2color(labels, cmap=None):
    '''
    Assign one distinct matplotlib color to every class label.

    A label whose lower-cased name is itself an entry of the chosen palette
    (e.g. ``"Red"`` -> ``"red"``) keeps that color; every other label takes
    the next unused palette entry, in palette order.

    Parameters
    ----------
    labels: a list of distinct strings
    cmap: optional name of a matplotlib colormap. When given (and not empty),
        ``len(labels)`` colors are sampled from it as RGBA tuples.

    Returns
    -------
    list
        One color per label, in the order of ``labels``: a color name that
        matplotlib understands, or an RGBA tuple.
    '''
    n = len(labels)
    if n == 0:
        return []

    # palette: list of (name used for label matching, color spec returned)
    if (cmap is not None) and (cmap != ''):
        # Resample the user colormap to n entries and evaluate each entry.
        # Calling the colormap works for both ListedColormap and
        # LinearSegmentedColormap (only the former has a `.colors` attribute).
        # Tuples keep the name matching below from ever comparing a string
        # with a numpy array.
        palette = [(None, tuple(c)) for c in plt.get_cmap(cmap, n)(range(n))]
    elif n <= 10:
        # Tableau colors under their plain CSS names ("tab:blue" -> "blue")
        palette = [(name, name) for name in (c.replace('tab:', '') for c in mcolors.TABLEAU_COLORS)]
    elif n <= 148:
        palette = [(name, name) for name in mcolors.CSS4_COLORS]
    elif n <= 949:
        # XKCD colors are only recognized by matplotlib WITH their "xkcd:"
        # prefix: match on the bare name but return the prefixed spec.
        palette = [(c.replace('xkcd:', ''), c) for c in mcolors.XKCD_COLORS]
    else:  # very unlikely
        palette = [(None, tuple(c)) for c in plt.get_cmap('viridis', n)(range(n))]

    available = deque(palette)
    colors = []
    for label in labels:
        key = str(label).lower()
        match = next((entry for entry in available if entry[0] == key), None)
        if match is not None:
            # the label names a palette color: use it and make it unavailable
            available.remove(match)
            colors.append(match[1])
        else:
            colors.append(available.popleft()[1])
    return colors
def toydata():
    """Return a tiny hard-coded two-class dataset.

    The DataFrame has columns ``label`` ('Red' or 'Blue'), ``F1`` and ``F2``,
    with three 'Red' rows followed by three 'Blue' rows.
    """
    red_points = [
        (0.12375, 0.8516666666666667),
        (0.19, 0.8916666666666666),
        (0.27375, 0.9233333333333333),
    ]
    blue_points = [
        (0.50625, 0.785),
        (0.38375, 0.6733333333333333),
        (0.28875, 0.595),
    ]
    rows = [['Red', f1, f2] for f1, f2 in red_points]
    rows += [['Blue', f1, f2] for f1, f2 in blue_points]
    return pd.DataFrame(rows, columns=['label', 'F1', 'F2'])
class CoordinateProjection2d:
    """
    Project data on two of its coordinates.

    A minimal scikit-learn-style embedder: ``transform`` keeps only the
    columns ``dim0`` and ``dim1``; ``inverse_transform`` lifts 2-D points back
    to the full feature space, filling every other feature with the mean
    recorded by ``fit``.
    """

    def __init__(self, dim0=0, dim1=1, **kwargs):
        # **kwargs absorbs arguments shared by the other embedders
        # (e.g. n_components=2) so every embedder can be built the same way.
        self.dims = [dim0, dim1]

    def fit(self, X, y=None):
        """Record the per-feature mean of X and return self.

        ``y`` is ignored; it is accepted for scikit-learn API compatibility.

        Raises
        ------
        ValueError
            If dim0 or dim1 is not a valid column index of X.
        """
        X = np.asarray(X)
        self._check_dims(X)
        self.mean = X.mean(axis=0)
        return self

    def transform(self, X):
        """Return columns ``dim0`` and ``dim1`` of X as an (n_samples, 2) array."""
        X = np.asarray(X)
        self._check_dims(X)
        return X[:, self.dims]

    def fit_transform(self, X, y=None):
        """Equivalent to ``fit(X, y).transform(X)``."""
        return self.fit(X, y).transform(X)

    def inverse_transform(self, Z):
        """Map 2-D points back to the original feature space.

        Columns ``dim0``/``dim1`` receive Z; all other features are set to the
        mean seen in ``fit``.
        """
        X = np.ones((len(Z), 1)) * self.mean
        self._check_dims(X)
        X[:, self.dims] = Z
        return X

    def _check_dims(self, X):
        """Raise ValueError if dim0 or dim1 is not a valid column index of X."""
        n_features = X.shape[1]
        for name, dim in zip(("dim0", "dim1"), self.dims):
            if dim >= n_features:
                raise ValueError(f"{name}={dim} exceeds the number of features {n_features}")
            if dim < -n_features:
                raise ValueError(f"{name}={dim} is out of range for {n_features} features")
class InteractiveDecisionBoundary:
    """
    Gradio app to draw, upload or load a dataset, train a scikit-learn
    classifier on it and visualize the classifier's decision boundary.

    Data are shown in a 2-D embedding of the (optionally normalized and
    jittered) features. The decision boundary is drawn by mapping a grid of
    2-D points back to feature space with the embedder's
    ``inverse_transform`` and coloring each grid point by the model's
    prediction.
    """
    DATASET_FILE = "dataset.csv"
    MODEL_FILE = "model.pkl"
    FIGURE_BASENAME = "figure"
    CODE_FILE = "generated_code.py"
    EXPORT_CODE_TEMPLATE = "export_code_template.py.j2"
    # rows shown in the data table for loaded (non-drawn) datasets
    MAX_TABLE_ROWS = 100

    def __init__(self, width, height):
        # Canvas size in pixels; plot() sizes the figure as width/100 x
        # height/100 inches so mouse clicks can be mapped onto data coordinates.
        self.canvas_width = width
        self.canvas_height = height

        # classifiers: only those listed in model_imports.yaml
        supported_classifier_names = yaml.safe_load(Path("model_imports.yaml").read_text())
        self.classifiers = {
            k: v for k, v in get_sklearn_classifiers().items() if k in supported_classifier_names
        }
        self.model_class = LinearSVC
        self.model_args = ""
        self.model = self.model_class()
        # (X, y) that self.model was last fitted on by plot(); None when the
        # most recent training attempt did not succeed
        self._train_data = None

        self.dataloaders = get_sklearn_dataloaders()

        # embedders: built-in projections plus the sklearn.decomposition
        # classes listed in embedder_imports.yaml
        supported_embedder_names = yaml.safe_load(Path("embedder_imports.yaml").read_text())
        self.embedders = {
            'CoordinateProjection2d': CoordinateProjection2d,
            'GaussianRandomProjection': sklearn.random_projection.GaussianRandomProjection,
            'SparseRandomProjection': sklearn.random_projection.SparseRandomProjection,
        }
        for cls_name, cls in inspect.getmembers(sklearn.decomposition, inspect.isclass):
            if cls_name in supported_embedder_names:
                self.embedders[cls_name] = cls

        # normalizers
        self.normalizers = {
            'None': None,
            'MinMaxScaler': sklearn.preprocessing.MinMaxScaler,
            'StandardScaler': sklearn.preprocessing.StandardScaler,
        }

        # data embedding and preprocessing values
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""
        self.normalizer_class = None
        self.jitter_std = 0

        # todo: support arbitrary number of classes and user-defined class labels
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        self.dataset_type = 'Draw2D'
        self.custom_selected = True

        # options
        self.num_dots = 200
        self.dpi = 100
        self.cmap = None
        self.precision = 2  # number of decimal places to show in datatable
        self.marker_size = 100
        self.data_image = None
        self.boundary_image = None
        # pixel position of the drawing area's top-left corner (set by plot)
        self._axis_topleft = (0, 0)
        self.figure_extension = ".svg"
        self.css = """
#my-button {
height: 30px;
font-size: 16px;
}
#rowheight {
height: 90px;
}
.file-chooser {
height: 150px;
}
.hidden-button {
display: none;
}
.report-table {
border: 0 !important;
}
.report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
border: 0 !important;
padding: 6px 12px;
text-align: center;
}"""

    def _get_features(self):
        """Return the raw feature matrix: every dataset column except ``target``.

        No normalization or jittering is applied here (see ``_process_features``).

        Raises
        ------
        ValueError
            If the dataset has no rows.
        """
        X = self.dataset.loc[:, self.dataset.columns != 'target'].values
        if len(X) == 0:
            raise ValueError("The dataset is empty or not properly formatted.")
        return X

    def _process_features(self, features):
        """Apply the selected normalizer (fitted on ``features``) and Gaussian jitter.

        The input array is never modified in place.
        """
        if self.normalizer_class is not None:
            features = self.normalizer_class().fit_transform(features)
        if self.jitter_std > 0:
            # out-of-place add: an in-place `+=` fails on integer arrays
            features = features + np.random.normal(0, self.jitter_std, features.shape)
        return features

    def _embed_features(self, features, return_embedder=False):
        """Embed ``features`` into 2-D with the selected embedder.

        Returns the embedded array, or ``(embedded, fitted_embedder)`` when
        ``return_embedder`` is True.
        """
        embedder = self.embedder_class(n_components=2, **parse_param_string(self.embedder_args))
        features = embedder.fit_transform(features)
        if return_embedder:
            return features, embedder
        return features

    def _reset_data_processing_and_embedding(self):
        """Restore the default preprocessing (none) and embedder (CoordinateProjection2d)."""
        self.normalizer_class = None
        self.jitter_std = 0
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""

    def plot(self, decision_boundary=False, save_figure=False):
        '''
        Plot data and decision boundary with matplotlib and return as PIL image.

        Parameters
        ----------
        decision_boundary: if True, (re)train a ``model_class`` instance on the
            processed features, store it in ``self.model`` and overlay its
            predictions on a grid.
        save_figure: if True, also write the figure to
            ``FIGURE_BASENAME + figure_extension``.

        Errors raised while processing, embedding or training are logged and
        shown to the user via ``gr.Info``; the partial figure is still returned.
        '''
        logger.info("Initializing figure")
        fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.), dpi=self.dpi)
        # set entire figure to be the canvas to allow simple conversion of mouse
        # position to coordinates in the figure
        ax = fig.add_axes([0., 0., 1., 1.])
        ax.margins(x=0, y=0)  # no padding in both directions
        if self.dataset_type == 'Draw2D':
            # draw canvas boundary
            # DO NOT CHANGE THE COLOR OF THE BOUNDARY
            # IT WILL BREAK THE ORIGIN COORDINATE DETECTION
            ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='black')
            # near-black spines (gray level 0.1*255) stay above the black
            # threshold used by _detect_axis_topleft
            for spine in ax.spines.values():
                spine.set_color((0.1, 0.1, 0.1))
        # TODO: allow showing x and y axes with ticks and labels
        if decision_boundary:
            # forget the previous fit; set again only if training succeeds
            self._train_data = None
        if self.dataset is not None and len(self.dataset) > 0:
            try:
                self._draw_data(ax, decision_boundary)
            except Exception as e:
                logger.exception("Plotting failed")
                gr.Info(f"⚠️ {e}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        if save_figure:
            fig.savefig(f"{self.FIGURE_BASENAME}{self.figure_extension}")
        plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        # detect axis pixel positions
        if self.dataset_type == 'Draw2D':
            self._axis_topleft = self._detect_axis_topleft(img)
        else:
            self._axis_topleft = 0, 0
        return img

    def _draw_data(self, ax, decision_boundary):
        """Scatter the embedded dataset on ``ax``; optionally train and draw the decision regions."""
        X = self._get_features()
        y = self.dataset['target'].values
        logger.info("Data:\n" + str(X))
        logger.info("Target:\n" + str(y))
        # preprocess features
        X = self._process_features(X)
        # embed features to 2D for visualization
        Z, embedder = self._embed_features(X, return_embedder=True)
        labels = np.unique(y)
        colors = label2color(labels, cmap=self.cmap)
        logger.info("Classes:\n" + str(labels))
        logger.info("Colors:\n" + str(colors))
        l2c = dict(zip(labels, colors))
        # scatter plots for data, one call per class
        for label, color in zip(labels, colors):
            subset = Z[y == label]
            ax.scatter(subset[:, 0], subset[:, 1], color=color, label=label, s=self.marker_size)
        ax.legend()
        if decision_boundary:
            self._draw_decision_boundary(ax, X, y, Z, embedder, l2c)

    def _draw_decision_boundary(self, ax, X, y, Z, embedder, l2c):
        """Fit a new model on (X, y) and color a num_dots x num_dots grid by its predictions.

        Grid points are mapped back to feature space with
        ``embedder.inverse_transform``. Normalization/jittering must NOT be
        applied to them, because X is already processed.
        """
        model = self.model_class(**parse_param_string(self.model_args))
        model.fit(X, y)
        self.model = model
        self._train_data = (X, y)
        if self.dataset_type == 'Draw2D':
            # the drawing canvas is the unit square
            xs = np.linspace(0, 1, self.num_dots)
            ys = np.linspace(0, 1, self.num_dots)
        else:
            xs = np.linspace(Z[:, 0].min(), Z[:, 0].max(), self.num_dots)
            ys = np.linspace(Z[:, 1].min(), Z[:, 1].max(), self.num_dots)
        xx, yy = np.meshgrid(xs, ys)
        grid = np.c_[xx.ravel(), yy.ravel()]
        # inverse-transform once: the grid holds num_dots**2 points
        grid_features = embedder.inverse_transform(grid)
        logger.debug("grid\n%s", grid)
        logger.debug("inverse\n%s", grid_features)
        preds = model.predict(grid_features)
        ax.scatter(xx.ravel(), yy.ravel(), c=[l2c[p] for p in preds], s=1, alpha=0.5)

    @staticmethod
    def _detect_axis_topleft(img):
        """Return the (x, y) pixel of the top-left corner of the largest black contour in ``img``.

        In Draw2D mode this is the black unit-square border drawn by ``plot``;
        add_point uses it to convert clicks to data coordinates. Returns
        (0, 0) when no black contour is found.
        """
        gray = cv2.cvtColor(np.array(img.convert("RGB")), cv2.COLOR_RGB2GRAY)
        black_mask = gray < 0.05 * 255
        contours, _ = cv2.findContours(black_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # keep the contour whose bounding box has the largest area
        max_area = 0
        most_likely_topleft = 0, 0
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            if w * h > max_area:
                max_area = w * h
                most_likely_topleft = x, y
        return most_likely_topleft

    def _data_table(self):
        """Build a visible ``gr.Dataframe`` of the dataset rounded to ``self.precision``.

        Drawn (Draw2D) datasets are shown in full; other datasets show only
        their first ``MAX_TABLE_ROWS`` rows.
        """
        # TODO: need to make it explicit that loaded data only shows the first rows
        data = self.dataset
        if self.dataset_type != 'Draw2D':
            data = data[:self.MAX_TABLE_ROWS]
        return gr.Dataframe(
            value=data.round(self.precision),
            visible=True,
            headers=list(self.dataset.columns),
        )

    def update_resolution(self, num_dots):
        """Set the decision-boundary grid resolution and re-plot with the boundary."""
        self.num_dots = num_dots
        return self.plot(decision_boundary=True)

    def update_dpi(self, dpi):
        """Set the figure DPI and re-plot with the boundary."""
        self.dpi = dpi
        return self.plot(decision_boundary=True)

    def update_cmap(self, cmap):
        """Set the colormap name used for class colors and re-plot with the boundary."""
        self.cmap = cmap
        return self.plot(decision_boundary=True)

    def update_precision(self, precision):
        """Set the number of decimals shown in the data table and refresh the table."""
        self.precision = precision
        return self._data_table()

    def update_marker_size(self, size):
        """Set the scatter marker size and re-plot with the boundary."""
        self.marker_size = size
        return self.plot(decision_boundary=True)

    def add_point(self, evt: gr.SelectData, label):
        '''
        Mouse click to add a point.

        In Draw2D mode the clicked pixel is converted to unit-square
        coordinates, relative to the detected top-left corner of the drawing
        area, and stored with ``label``. Clicks outside the unit square are
        ignored. Returns the updated plot and data table.
        '''
        if self.custom_selected:
            if self.dataset_type != 'Draw2D':
                self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
                self.dataset_type = 'Draw2D'
            shift_x, shift_y = self._axis_topleft
            # normalize clicked position to [0, 1]
            x = (evt.index[0] - shift_x) / self.canvas_width
            y = 1 - (evt.index[1] - shift_y) / self.canvas_height  # flip y-axis to match matplotlib
            if 0 <= x <= 1 and 0 <= y <= 1:
                self.dataset.loc[len(self.dataset)] = [label, x, y]
            logger.info(f'clicked ({evt.index[0]}, {evt.index[1]}), mapped to ({x}, {y})')
        vis = self.plot()
        return vis, self._data_table()

    def train(self):
        """Train a model and show its decision boundary plus a classification report.

        The report is computed on exactly the processed (normalized/jittered)
        features the model was fitted on. Returns (image, HTML visibility
        update, report HTML).
        """
        image = self.plot(decision_boundary=True)
        if self._train_data is None:
            # plot() has already reported a training error, or there is no data
            if self.dataset is None or len(self.dataset) == 0:
                gr.Info("⚠️ The dataset is empty; add or load some data first.")
            return image, gr.HTML(visible=False), ""
        X, y = self._train_data
        try:
            pred = self.model.predict(X)
            df = pd.DataFrame(classification_report(y, pred, output_dict=True)).T
            summary = df.to_html(classes="report-table", float_format="%.2f")
            return image, gr.HTML(visible=True), "<b>Classification report</b><br>" + summary
        except Exception as e:
            logger.exception("Computing the classification report failed")
            gr.Info(f"⚠️ {e}")
            return image, gr.HTML(visible=False), ""

    def clear(self):
        """Remove all data points; return the empty plot and a hidden data table.

        Preprocessing/embedding settings are left unchanged.
        """
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        return self.plot(), gr.Dataframe(visible=False)

    def save(self):
        """Save the current dataset (see ``save_data``)."""
        self.save_data()

    def save_data(self):
        """Write the dataset to ``DATASET_FILE`` as CSV and return the file name."""
        # TODO: allow user-specified filename
        self.dataset.to_csv(self.DATASET_FILE, index=False)
        logger.info(f"{self.DATASET_FILE} updated")
        return self.DATASET_FILE

    def update_model(self, classifier_name):
        """Switch the classifier class and reset its arguments.

        The stored arguments are cleared too, so re-plots triggered before the
        next training do not pass the previous classifier's arguments to the
        new one. Returns "" to clear the arguments textbox.
        """
        if classifier_name not in self.classifiers:
            raise gr.Error(f"Unsupported classifier: {classifier_name}")
        self.model_class = self.classifiers[classifier_name]
        self.model_args = ""
        self.args_textbox.value = ""
        logger.info(f'Updated model to {self.model_class}')
        return ""

    def save_model(self):
        """Pickle ``self.model`` to ``MODEL_FILE`` and return the file name."""
        with open(self.MODEL_FILE, "wb") as f:
            pickle.dump(self.model, f)
        logger.info(f"{self.MODEL_FILE} updated")
        return self.MODEL_FILE

    def save_code(self):
        """Render a standalone script for the current setup and write it to ``CODE_FILE``.

        The import statements for the model, embedder and normalizer come from
        the corresponding ``*_imports.yaml`` files; the CoordinateProjection2d
        source is inlined instead. The rendered template is formatted with
        black. Returns the file name.

        Raises
        ------
        ValueError
            If the model, embedder or normalizer has no entry in its imports file.
        """
        model_class = str(self.model_class.__name__)
        model_imports = yaml.safe_load(Path("model_imports.yaml").read_text())
        if model_class not in model_imports:
            raise ValueError(f"Model {model_class} not found in model_imports.yaml")
        model_import_stmt = f"{model_imports[model_class]}"

        embedder_class = str(self.embedder_class.__name__)
        if embedder_class == "CoordinateProjection2d":
            # not importable from a library: inline its source instead
            embedder_import_stmt = f"\n\n\n{inspect.getsource(CoordinateProjection2d)}".rstrip()
        else:
            embedder_imports = yaml.safe_load(Path("embedder_imports.yaml").read_text())
            if embedder_class not in embedder_imports:
                raise ValueError(f"Embedder {embedder_class} not found in embedder_imports.yaml")
            embedder_import_stmt = f"\n{embedder_imports[embedder_class]}"

        if self.normalizer_class is not None:
            normalizer_class = str(self.normalizer_class.__name__)
            normalizer_imports = yaml.safe_load(Path("normalizer_imports.yaml").read_text())
            if normalizer_class not in normalizer_imports:
                raise ValueError(f"Normalizer {normalizer_class} not found in normalizer_imports.yaml")
            normalizer_import_stmt = f"\n{normalizer_imports[normalizer_class]}"
        else:
            normalizer_import_stmt = ""

        if self.dataset_type == 'Draw2D':
            # drawn data live in the unit square
            x_min, x_max, y_min, y_max = 0, 1, 0, 1
        else:
            # passed as source-code strings; presumably the template names the
            # embedded data X_embedded -- keep in sync with the template
            x_min = "X_embedded[:, 0].min()"
            x_max = "X_embedded[:, 0].max()"
            y_min = "X_embedded[:, 1].min()"
            y_max = "X_embedded[:, 1].max()"

        model_params = parse_param_string(self.model_args)
        if len(model_params) == 0:
            model_params_text = ""
        else:
            model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"
        if embedder_class == "CoordinateProjection2d":
            embedder_args = {**parse_param_string(self.embedder_args)}
        else:
            embedder_args = {"n_components": 2, **parse_param_string(self.embedder_args)}
        embedder_args_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in embedder_args.items()]) + "\n\t"

        template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
        variables = {
            'model_import_statement': model_import_stmt,
            'embedder_import_statement': embedder_import_stmt,
            "normalizer_import_statement": normalizer_import_stmt,
            'dataset_file': self.DATASET_FILE,
            'embedder_class': embedder_class,
            'embedder_args': embedder_args_text,
            'model_class': model_class,
            'model_params': model_params_text,
            'fig_width': self.canvas_width / 100,
            'fig_height': self.canvas_height / 100,
            'dpi': 100,
            'num_dots': self.num_dots,
            'x_min': x_min,
            'x_max': x_max,
            'y_min': y_min,
            'y_max': y_max,
            "normalize": self.normalizer_class is not None,
            "normalizer_class": self.normalizer_class.__name__ if self.normalizer_class is not None else "",
            "jitter": self.jitter_std > 0,
            "jitter_scale": self.jitter_std,
        }
        rendered_code = template.render(variables)
        rendered_code = black.format_str(rendered_code, mode=black.FileMode())
        Path(self.CODE_FILE).write_text(rendered_code)
        logger.info(f"{self.CODE_FILE} updated")
        return self.CODE_FILE

    def update_figure_extension(self, ext):
        """Set the file extension (e.g. '.svg') used by ``save_figure``."""
        self.figure_extension = ext
        logger.info(f'updated figure extension: {self.figure_extension}')

    def save_figure(self):
        """Re-plot with the boundary, save the figure and return its file name."""
        self.plot(decision_boundary=True, save_figure=True)
        return f"{self.FIGURE_BASENAME}{self.figure_extension}"

    def update_args(self, model_args):
        """Store the classifier argument string (parsed with ``parse_param_string`` when training)."""
        self.model_args = model_args
        logger.info(f'updated model_args: {self.model_args}')

    def update_embedder(self, embedder):
        """Select the embedder by name and re-plot."""
        if embedder not in self.embedders:
            raise gr.Error(f"Unsupported embedder: {embedder}")
        self.embedder_class = self.embedders[embedder]
        logger.info(f'updated Embedder: {self.embedder_class}')
        return self.plot()

    def update_embedder_args(self, embedder_args):
        """Store the embedder argument string and re-plot."""
        self.embedder_args = embedder_args
        logger.info(f'updated Embedder args: {self.embedder_args}')
        return self.plot()

    def update_normalizer(self, normalizer):
        """Select the normalizer by name ('None' disables it); return the new plot and data table."""
        self.normalizer_class = self.normalizers[normalizer]
        logger.info(f'updated Normalizer: {self.normalizer_class}')
        return self.plot(), self._data_table()

    def update_jittering(self, jitter_std):
        """Set the jitter standard deviation (0 if not a number) and re-plot."""
        try:
            self.jitter_std = float(jitter_std)
        except ValueError:
            self.jitter_std = 0
        logger.info(f'updated Jittering std: {self.jitter_std}')
        return self.plot()

    def handle_dataset_radio(self, type):
        """Switch between the 'Draw2D', 'Upload' and 'sklearn' dataset modes.

        Resets preprocessing/embedding and clears the dataset. Returns the
        new plot, the data table, then updates for, in order: file chooser,
        sklearn dataset selector, normalizer selector, jitter textbox,
        embedder selector, embedder-args textbox, label radio, Clear button
        and Save button.

        Raises
        ------
        gr.Error
            For an unknown dataset type.
        """
        if type not in ('Draw2D', 'Upload', 'sklearn'):
            logger.error(f'Error - unknown dataset type: {type}')
            raise gr.Error(f"Unknown dataset type: {type}")
        draw = type == 'Draw2D'
        self.dataset_type = type
        self.custom_selected = draw
        new_fields = (
            gr.File(visible=type == 'Upload'),
            gr.Dropdown(visible=type == 'sklearn', value="None"),
            gr.Dropdown(visible=not draw, value="None"),
            gr.Textbox(visible=not draw, value="0"),
            gr.Dropdown(visible=not draw, value="CoordinateProjection2d"),
            # clear the text so it matches the reset embedder arguments
            gr.Textbox(visible=not draw, value=""),
            gr.Radio(visible=draw),
            gr.Button(visible=draw),
            gr.Button(visible=False),
        )
        self._reset_data_processing_and_embedding()
        plot, data_table = self.clear()
        return plot, data_table, *new_fields

    def load_local_data_and_plot(self, filename):
        """Load an uploaded file as the dataset (via ``read``) and plot it.

        The data must contain a ``target`` column; every other column is used
        as a feature. Target values are converted to strings. Returns the
        plot and the data table.
        """
        if filename is not None:
            dataset = read(filename)
            if 'target' not in dataset.columns:
                raise gr.Error("The uploaded dataset must contain a 'target' column.")
            dataset['target'] = dataset['target'].astype(str)
            self.dataset = dataset
            self.dataset_type = 'Upload'
            logger.info(f'Loaded dataset from {filename}')
        vis = self.plot()
        return vis, self._data_table()

    def load_sklearn_data_and_plot(self, datasetname):
        """Load a scikit-learn dataset by name, store it with string targets and plot it.

        Class indices are mapped to ``target_names`` when available,
        otherwise to 'C<value>'; missing feature names become 'F<i>'.
        Returns the plot and the data table.
        """
        if datasetname is not None and datasetname != "None":
            if datasetname not in self.dataloaders:
                raise gr.Error(f"Unknown dataset: {datasetname}")
            dataset = self.dataloaders[datasetname]()
            X = dataset.data
            y = np.asarray(dataset.target)
            if hasattr(dataset, 'feature_names'):
                feature_names = dataset.feature_names
            else:
                feature_names = [f'F{i}' for i in range(np.shape(X)[1])]
            if hasattr(dataset, 'target_names'):
                labels = dataset.target_names
                y = np.array([labels[i] for i in y])
            else:
                y = np.array([f'C{v}' for v in y])
            self.dataset = pd.DataFrame(X, columns=feature_names)
            self.dataset['target'] = y.astype(str)
            self.dataset_type = 'sklearn'
            logger.info(f'Loaded dataset {datasetname}')
        vis = self.plot()
        return vis, self._data_table()

    def launch(self):
        """Build the Gradio interface, wire up all event handlers and start the app."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Interactive Decision Boundary Visualizer</div>")
            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    self.data_image = gr.Image(
                        value=self.plot(),
                        container=True,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        show_download_button=False,
                        show_label=False,
                    )
                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_radio = gr.Radio(
                            ["Draw2D", "Upload", "sklearn"],
                            value="Draw2D",
                            label="Dataset type",
                            elem_id="rowheight",
                        )
                        # upload data
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_classes="file-chooser")
                        self.file_chooser = file_chooser
                        # sklearn data dropdown menu
                        sklearn_data_selector = gr.Dropdown(
                            choices=list(self.dataloaders),
                            label='Select dataset',
                            value='None',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.sklearn_data_selector = sklearn_data_selector
                        # normalization
                        normalizer_selector = gr.Dropdown(
                            choices=list(self.normalizers),
                            label='Select normalizer',
                            value='None',
                            visible=False,
                        )
                        self.normalizer_selector = normalizer_selector
                        # jittering
                        jittering_textbox = gr.Textbox(label="Set jittering std", value="0", visible=False)
                        self.jittering_textbox = jittering_textbox
                        # embedder
                        embedder_selector = gr.Dropdown(
                            choices=list(self.embedders),
                            label='Select embedder (only for plotting)',
                            value='CoordinateProjection2d',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.embedder_selector = embedder_selector
                        embedder_args_textbox = gr.Textbox(label="Embedder arguments", visible=False)
                        self.embedder_args_textbox = embedder_args_textbox
                        # custom data
                        label = gr.Radio(["Gray", "Orange", "Blue"], value="Gray", label="Choose point label", visible=True, elem_id="rowheight")
                        self.label = label
                        with gr.Row():
                            btn_clear = gr.Button("Clear", visible=True, elem_id="my-button")
                            self.btn_clear = btn_clear
                            btn_save = gr.Button("Save", visible=False, elem_id="my-button")
                            self.btn_save = btn_save
                        data_table = gr.Dataframe(visible=False)
                    # classifier selector
                    with gr.Tab("Classifier"):
                        # specify model
                        model_selector = gr.Dropdown(
                            choices=list(self.classifiers),
                            label='Select Classifier',
                            value='LinearSVC',
                            allow_custom_value=True,
                        )
                        self.model_selector = model_selector
                        # specify arguments
                        args_textbox = gr.Textbox(label="Classifier arguments")
                        self.args_textbox = args_textbox
                        model_selector.change(fn=self.update_model, inputs=model_selector, outputs=args_textbox)
                        btn_train = gr.Button("Train Model")
                        classification_summary = gr.HTML(visible=False)
                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        btn_export_data = gr.Button("Data")
                        btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
                        btn_export_model = gr.Button('Model')
                        btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
                        btn_export_code = gr.Button('Code')
                        btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
                        with gr.Row():
                            btn_export_figure = gr.Button('Figure')
                            btn_export_figure_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_figure_hidden", elem_classes="hidden-button")
                            figure_extension_selector = gr.Dropdown(choices=['.svg', '.pdf', '.png', '.jpeg'], label="File extension", value=".svg")
                    with gr.Tab("Options"):
                        grid_resolution_slider = gr.Slider(minimum=100, maximum=1000, value=200, step=10, label="Decision boundary grid resolution")
                        cmap_textbox = gr.Textbox(label="Colormap")
                        precision_slider = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="# decimal place in datatable")
                        marker_size_slider = gr.Slider(minimum=0, maximum=200, value=100, step=5, label="Marker size")
                    with gr.Tab("Usage"):
                        gr.Markdown(Path('usage.md').read_text())

            # event handlers for GUI elements
            self.data_image.select(self.add_point, inputs=label,
                                   outputs=(self.data_image, data_table))
            dataset_radio.change(
                fn=self.handle_dataset_radio,
                inputs=dataset_radio,
                outputs=(
                    self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_textbox, embedder_selector, embedder_args_textbox, label, btn_clear, btn_save
                ),
            )
            # events for custom dataset
            btn_clear.click(fn=self.clear, outputs=(self.data_image, data_table))
            btn_save.click(fn=self.save)
            # events for local dataset
            file_chooser.change(fn=self.load_local_data_and_plot,
                                inputs=file_chooser,
                                outputs=(self.data_image, data_table))
            # events for sklearn dataset
            sklearn_data_selector.change(fn=self.load_sklearn_data_and_plot,
                                         inputs=sklearn_data_selector,
                                         outputs=(self.data_image, data_table))
            embedder_selector.change(fn=self.update_embedder,
                                     inputs=embedder_selector,
                                     outputs=self.data_image)
            embedder_args_textbox.change(
                fn=self.update_embedder_args,
                inputs=embedder_args_textbox,
                outputs=self.data_image,
            )
            normalizer_selector.change(
                fn=self.update_normalizer,
                inputs=normalizer_selector,
                outputs=(self.data_image, data_table),
            )
            jittering_textbox.change(
                fn=self.update_jittering,
                inputs=jittering_textbox,
                outputs=self.data_image,
            )
            # Store the classifier arguments, THEN train. Chaining with .then()
            # enforces this order; two separate .click listeners on the same
            # button are not guaranteed to run in sequence.
            btn_train.click(fn=self.update_args, inputs=args_textbox).then(
                fn=self.train,
                outputs=(self.data_image, classification_summary, classification_summary),
            )
            # events for export
            # create files on the fly using hidden download buttons
            # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
            btn_export_data.click(
                fn=self.save_data,
                inputs=None,
                outputs=[btn_export_data_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_data_hidden').click()"
            )
            btn_export_model.click(
                fn=self.save_model,
                inputs=None,
                outputs=[btn_export_model_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_model_hidden').click()"
            )
            btn_export_code.click(
                fn=self.save_code,
                inputs=None,
                outputs=[btn_export_code_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_code_hidden').click()"
            )
            btn_export_figure.click(
                fn=self.save_figure,
                inputs=None,
                outputs=[btn_export_figure_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_figure_hidden').click()"
            )
            figure_extension_selector.change(self.update_figure_extension, inputs=figure_extension_selector)
            # events for options
            grid_resolution_slider.change(self.update_resolution, inputs=grid_resolution_slider, outputs=self.data_image)
            cmap_textbox.submit(self.update_cmap, inputs=cmap_textbox, outputs=self.data_image)
            precision_slider.change(self.update_precision, inputs=precision_slider, outputs=data_table)
            marker_size_slider.change(self.update_marker_size, inputs=marker_size_slider, outputs=self.data_image)
        demo.launch()
# Start the app only when run as a script (e.g. `python app.py`), so that
# importing this module does not read config files and launch a server.
if __name__ == "__main__":
    visualizer = InteractiveDecisionBoundary(width=1200, height=900)
    visualizer.launch()