# decision_boundary / old_code/decision_boundary.py
# Author: joel-woodfield
# Commit message: "Basic implementation of react version"
# Commit: 9b05cbd
from collections import deque
from pathlib import Path
import pickle
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import black
import cv2
import inspect
import numpy as np
import pandas as pd
import io
from jinja2 import Template
from PIL import Image
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.base import ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report
import traceback
import yaml
from util import *
import logging
logging.basicConfig(
level=logging.INFO, # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
format="%(asctime)s [%(levelname)s] %(message)s", # log format
)
logger = logging.getLogger("ELVIS")
# TODO:
# - support for session: load a previous session and continue from there
def label2color(labels, cmap=None):
    '''
    Assign one matplotlib color to each class label.

    A label that names a color in the chosen palette (compared
    case-insensitively, e.g. 'Gray' -> 'gray') is drawn in that color; every
    other label takes the next unused palette color, in palette order.

    Parameters
    ----------
    labels: a list of distinct labels; they are matched via ``str(label).lower()``
    cmap: optional matplotlib colormap name. If given (non-empty),
        ``len(labels)`` colors are sampled evenly from it.

    Returns
    -------
    list with one matplotlib color specification per label, in the order of
    ``labels``: a plain color name, an 'xkcd:'-prefixed name, or a hex string
    for colors sampled from a colormap.
    '''
    n = len(labels)
    if n == 0:
        return []

    def sample_cmap(name):
        # Resampling to n entries and calling the colormap with integer
        # indices works for ListedColormap *and* LinearSegmentedColormap
        # (e.g. 'jet', which has no ``.colors`` attribute). Hex strings rather
        # than RGBA arrays keep the label/name comparison below well-defined.
        cm = plt.get_cmap(name, n)
        hexes = [mcolors.to_hex(cm(i), keep_alpha=True) for i in range(n)]
        return [(h, h) for h in hexes]

    # palette entries are (name used for matching labels, color spec returned)
    if cmap:  # None and '' both mean "no user-specified colormap"
        palette = sample_cmap(cmap)
    elif n <= 10:
        # Tableau names with the 'tab:' prefix dropped, i.e. the plain named
        # colors 'blue', 'orange', 'green', ...
        names = [c.replace('tab:', '') for c in mcolors.TABLEAU_COLORS]
        palette = [(c, c) for c in names]
    elif n <= 148:
        palette = [(c, c) for c in mcolors.CSS4_COLORS]
    elif n <= 949:
        # Match labels against the bare xkcd name but return the prefixed
        # name: bare xkcd names such as 'cloudy blue' are not valid
        # matplotlib color names on their own.
        palette = [(c.replace('xkcd:', ''), c) for c in mcolors.XKCD_COLORS]
    else:  # very unlikely
        palette = sample_cmap('viridis')

    pool = deque(palette)
    colors = []
    for label in labels:
        key = str(label).lower()
        match = next((entry for entry in pool if entry[0] == key), None)
        if match is None:
            # label is not a color name (or it was already used): take the next free color
            match = pool.popleft()
        else:
            pool.remove(match)
        colors.append(match[1])
    return colors
def toydata():
    '''
    Build a tiny two-class example dataset: three 'Red' points followed by
    three 'Blue' points in the unit square.

    Returns
    -------
    pandas.DataFrame with columns 'label', 'F1' and 'F2'.
    '''
    red_points = [
        (0.12375, 0.8516666666666667),
        (0.19, 0.8916666666666666),
        (0.27375, 0.9233333333333333),
    ]
    blue_points = [
        (0.50625, 0.785),
        (0.38375, 0.6733333333333333),
        (0.28875, 0.595),
    ]
    rows = [['Red', f1, f2] for f1, f2 in red_points]
    rows += [['Blue', f1, f2] for f1, f2 in blue_points]
    return pd.DataFrame(rows, columns=['label', 'F1', 'F2'])
class CoordinateProjection2d:
    """
    Project data on two of its original coordinates.

    A minimal embedder with the ``fit`` / ``transform`` / ``fit_transform`` /
    ``inverse_transform`` interface: ``transform`` keeps only the columns
    ``dim0`` and ``dim1``; ``inverse_transform`` maps 2-D points back to the
    full feature space, filling every other feature with its mean over the
    data seen in ``fit``.

    Inputs may be NumPy arrays or anything ``np.asarray`` accepts
    (nested lists, pandas DataFrames, ...).
    """

    def __init__(self, dim0=0, dim1=1, **kwargs):
        # **kwargs accepts (and ignores) arguments meant for other embedders,
        # e.g. ``n_components=2``, so every embedder can be built the same way.
        self.dims = [dim0, dim1]

    def fit(self, X):
        """Store the per-feature mean of X (used by ``inverse_transform``)."""
        X = np.asarray(X)
        self.mean = X.mean(axis=0)
        return self

    def transform(self, X):
        """Return columns ``dim0`` and ``dim1`` of X as an (n_samples, 2) array.

        Raises ValueError if either dimension does not exist in X.
        """
        X = np.asarray(X)
        self._check_dims(X)
        return X[:, self.dims]

    def fit_transform(self, X):
        """Equivalent to ``fit(X).transform(X)``."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, Z):
        """Map 2-D points Z back to the feature space seen in ``fit``.

        Columns ``dim0``/``dim1`` take the values of Z; every other column is
        the feature mean recorded by ``fit``.
        """
        Z = np.asarray(Z)
        # broadcast the mean row to one row per point in Z
        X = np.ones((len(Z), 1)) * self.mean
        self._check_dims(X)
        X[:, self.dims] = Z
        return X

    def _check_dims(self, X):
        """Raise ValueError if dim0 or dim1 is not a column index of X."""
        n_features = X.shape[1]
        if self.dims[0] >= n_features:
            raise ValueError(f"dim0={self.dims[0]} exceeds the number of features {n_features}")
        if self.dims[1] >= n_features:
            raise ValueError(f"dim1={self.dims[1]} exceeds the number of features {n_features}")
class InteractiveDecisionBoundary:
    """
    Gradio app to build a labelled dataset, train a scikit-learn classifier
    on it and visualise the classifier's decision boundary.

    Data comes from clicks on the canvas ('Draw2D'), an uploaded file
    ('Upload') or a scikit-learn dataset loader ('sklearn'). The classifier is
    trained on the (optionally normalized / jittered) feature matrix; the
    selected embedder is only used to project data to 2-D for plotting.
    """
    DATASET_FILE = "dataset.csv"
    MODEL_FILE = "model.pkl"
    FIGURE_BASENAME = "figure"
    CODE_FILE = "generated_code.py"
    EXPORT_CODE_TEMPLATE = "export_code_template.py.j2"

    def __init__(self, width, height):
        """
        Parameters
        ----------
        width, height: canvas size in pixels. The figure is created at
            ``width/100 x height/100`` inches, and clicks are normalized by
            these values in ``add_point``.
        """
        self.canvas_width = width
        self.canvas_height = height

        # classifiers: those provided by get_sklearn_classifiers() that are
        # listed in model_imports.yaml
        supported_classifier_names = yaml.safe_load(Path("model_imports.yaml").read_text())
        self.classifiers = {k: v for k, v in get_sklearn_classifiers().items() if k in supported_classifier_names}
        self.model_class = LinearSVC
        self.model_args = ""
        self.model = self.model_class()
        # (features, targets) that self.model was last fitted on; set by
        # plot(decision_boundary=True) and used by train() for its report.
        self._train_features = None

        self.dataloaders = get_sklearn_dataloaders()

        # embedders: used only to project the data to 2-D for plotting
        supported_embedder_names = yaml.safe_load(Path("embedder_imports.yaml").read_text())
        self.embedders = {
            'CoordinateProjection2d': CoordinateProjection2d,
            'GaussianRandomProjection': sklearn.random_projection.GaussianRandomProjection,
            'SparseRandomProjection': sklearn.random_projection.SparseRandomProjection,
        }
        module = getattr(sklearn, 'decomposition')
        for cls_name, cls in inspect.getmembers(module, inspect.isclass):
            if cls_name in supported_embedder_names:
                self.embedders[cls_name] = cls

        # normalizers
        self.normalizers = {
            'None': None,
            'MinMaxScaler': sklearn.preprocessing.MinMaxScaler,
            'StandardScaler': sklearn.preprocessing.StandardScaler,
        }

        # data embedding and preprocessing values
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""
        self.normalizer_class = None
        self.jitter_std = 0

        # todo: support arbitrary number of classes and user-defined class labels
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        self.dataset_type = 'Draw2D'
        self.custom_selected = True

        # options
        self.num_dots = 200  # decision-boundary grid points per axis
        self.dpi = 100
        self.cmap = None
        self.precision = 2  # number of decimal places to show in datatable
        self.marker_size = 100
        self.data_image = None
        self.boundary_image = None
        # pixel position of the canvas frame's top-left corner in the rendered
        # image; updated by plot() and used by add_point() to map clicks
        self._axis_topleft = (0, 0)
        self.figure_extension = ".svg"
        self.css ="""
#my-button {
height: 30px;
font-size: 16px;
}
#rowheight {
height: 90px;
}
.file-chooser {
height: 150px;
}
.hidden-button {
display: none;
}
.report-table {
border: 0 !important;
}
.report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
border: 0 !important;
padding: 6px 12px;
text-align: center;
}"""

    def _get_features(self):
        """Return the raw feature matrix (every column except 'target').

        No normalization or jittering is applied here; see _process_features.
        Raises ValueError if the dataset has no rows.
        """
        X = self.dataset.loc[:, self.dataset.columns != 'target'].values
        if len(X) == 0:
            raise ValueError("The dataset is empty or not properly formatted.")
        return X

    def _process_features(self, features):
        """Apply the selected normalizer (fitted on ``features``) and Gaussian
        jitter with standard deviation ``self.jitter_std``.

        The input array is never modified in place.
        """
        if self.normalizer_class is not None:
            normalizer = self.normalizer_class()
            features = normalizer.fit_transform(features)
        if self.jitter_std > 0:
            noise = np.random.normal(0, self.jitter_std, features.shape)
            # out-of-place add: in-place `+=` fails for integer feature arrays
            # (e.g. uploaded CSVs with int columns) and would mutate the input
            features = features + noise
        return features

    def _embed_features(self, features, return_embedder=False):
        """Fit the selected embedder with ``n_components=2`` plus the user's
        embedder arguments and return the 2-D embedding (and the fitted
        embedder if ``return_embedder``)."""
        embedder = self.embedder_class(n_components=2, **parse_param_string(self.embedder_args))
        features = embedder.fit_transform(features)
        if return_embedder:
            return features, embedder
        return features

    def _reset_data_processing_and_embedding(self):
        """Reset normalizer, jitter and embedder settings to their defaults."""
        self.normalizer_class = None
        self.jitter_std = 0
        self.embedder_class = CoordinateProjection2d
        self.embedder_args = ""

    def plot(self, decision_boundary=False, save_figure=False):
        '''
        Plot data and decision boundary with matplotlib and return as PIL image.

        Parameters
        ----------
        decision_boundary: if True, fit a fresh ``self.model_class`` on the
            processed features, store it in ``self.model`` and overlay its
            predictions on a grid in the 2-D embedded space.
        save_figure: if True, also write the figure to
            ``FIGURE_BASENAME + figure_extension``.

        Errors while processing data or fitting are logged and shown to the
        user with ``gr.Info``; an image is still returned. For the Draw2D
        canvas, ``self._axis_topleft`` is updated with the detected pixel
        position of the canvas frame.
        '''
        if decision_boundary:
            # Invalidate the previous fit: if fitting below fails or the dataset
            # is empty, train() must not report on a stale model.
            self._train_features = None

        logger.info("Initializing figure")
        fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=self.dpi)
        # set entire figure to be the canvas to allow simple conversion of mouse
        # position to coordinates in the figure
        ax = fig.add_axes([0., 0., 1., 1.])
        ax.margins(x=0, y=0)  # no padding in both directions

        if self.dataset_type == 'Draw2D':
            # draw canvas boundary
            # DO NOT CHANGE THE COLOR OF THE BOUNDARY
            # IT WILL BREAK THE ORIGIN COORDINATE DETECTION
            ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='black')
        for spine in ax.spines.values():
            spine.set_color((0.1, 0.1, 0.1))
        # TODO: allow showing x and y axes with ticks and labels

        if self.dataset is not None and len(self.dataset) > 0:
            try:
                X = self._get_features()
                y = self.dataset.target.values
                logger.info("Data:\n" + str(X))
                logger.info("Target:\n" + str(y))

                # preprocess features
                X = self._process_features(X)
                # embed features to 2D for visualization
                Z, embedder = self._embed_features(X, return_embedder=True)

                labels = np.unique(y)
                colors = label2color(labels, cmap=self.cmap)
                logger.info("Classes:\n" + str(labels))
                logger.info("Colors:\n" + str(colors))
                l2c = dict(zip(labels, colors))

                # scatter plots for data, one call per class so that the
                # legend gets one entry per class
                for label, color in zip(labels, colors):
                    subset = Z[y == label]
                    ax.scatter(subset[:, 0], subset[:, 1], color=color, label=label, s=self.marker_size)
                ax.legend()

                # plot the decision boundary
                if decision_boundary:
                    model = self.model_class(**parse_param_string(self.model_args))
                    model.fit(X, y)
                    self.model = model
                    self._train_features = (X, y)

                    # Sample a grid in the 2-D embedded space, map it back to the
                    # (processed) feature space with the embedder's
                    # inverse_transform and color each grid point by the
                    # predicted class. Normalization/jittering must NOT be
                    # applied to grid points: they already live in the
                    # processed feature space the model was fitted on.
                    if self.dataset_type == 'Draw2D':
                        # the drawing canvas spans [0, 1] x [0, 1]
                        xx, yy = np.meshgrid(np.linspace(0, 1, self.num_dots),
                                             np.linspace(0, 1, self.num_dots))
                    else:
                        xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), Z[:, 0].max(), self.num_dots),
                                             np.linspace(Z[:, 1].min(), Z[:, 1].max(), self.num_dots))
                    grid = np.c_[xx.ravel(), yy.ravel()]
                    grid_features = embedder.inverse_transform(grid)
                    logger.debug("grid:\n%s", grid)
                    logger.debug("inverse:\n%s", grid_features)
                    preds = model.predict(grid_features).reshape(xx.shape)
                    ax.scatter(xx.ravel(), yy.ravel(), c=[l2c[p] for p in preds.ravel()], s=1, alpha=0.5)
            except Exception as e:
                logger.exception("Plotting failed")
                gr.Info(f"⚠️ {e}")

        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        if save_figure:
            # save before closing the figure
            fig.savefig(f"{self.FIGURE_BASENAME}{self.figure_extension}")
        plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)

        # detect axis pixel positions
        if self.dataset_type == 'Draw2D':
            # The canvas frame is the only pure-black line drawn (see the
            # warning above), so the largest dark contour's bounding box gives
            # the frame's top-left pixel, used to map mouse clicks.
            array = np.array(img.convert("RGB"))
            bgr = cv2.cvtColor(array, cv2.COLOR_RGB2BGR)
            gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
            black_mask = gray < 0.05 * 255
            contours, _ = cv2.findContours(black_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # find the contour with the largest area
            max_area = 0
            most_likely_topleft = 0, 0
            for contour in contours:
                x, y, w, h = cv2.boundingRect(contour)
                area = w * h
                if area > max_area:
                    max_area = area
                    most_likely_topleft = x, y
            self._axis_topleft = most_likely_topleft
        else:
            self._axis_topleft = 0, 0

        # TODO: add a save function for saving screenshot
        return img

    def update_resolution(self, num_dots):
        """Set the decision-boundary grid resolution and re-plot."""
        self.num_dots = num_dots
        return self.plot(decision_boundary=True)

    def update_dpi(self, dpi):
        """Set the figure DPI and re-plot."""
        self.dpi = dpi
        return self.plot(decision_boundary=True)

    def update_cmap(self, cmap):
        """Set the colormap name used for class colors and re-plot."""
        self.cmap = cmap
        return self.plot(decision_boundary=True)

    def update_precision(self, precision):
        """Set the number of decimals shown in the data table; return the table."""
        self.precision = precision
        data_table = gr.Dataframe(
            value=self.dataset.round(self.precision),
            visible=True,
            headers=list(self.dataset.columns),
        )
        return data_table

    def update_marker_size(self, size):
        """Set the scatter marker size and re-plot."""
        self.marker_size = size
        return self.plot(decision_boundary=True)

    def add_point(self, evt: gr.SelectData, label):
        '''
        Mouse click to add a point.

        Only active for the Draw2D dataset (switches back to an empty Draw2D
        dataset if another dataset was loaded). The clicked pixel is mapped to
        [0, 1] x [0, 1] relative to the detected canvas frame, with the y-axis
        flipped to match matplotlib; clicks outside the frame are ignored.

        Returns the updated plot image and data table.
        '''
        if self.custom_selected:
            if self.dataset_type != 'Draw2D':
                self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
                self.dataset_type = 'Draw2D'
            shift_x, shift_y = self._axis_topleft
            # normalize clicked position to [0, 1]
            x = (evt.index[0] - shift_x) / self.canvas_width
            y = 1 - (evt.index[1] - shift_y) / self.canvas_height  # flip y-axis to match matplotlib
            if 0 <= x <= 1 and 0 <= y <= 1:
                self.dataset.loc[len(self.dataset)] = [label, x, y]
            logger.info(f'clicked ({evt.index[0]}, {evt.index[1]}), mapped to ({x}, {y})')
        vis = self.plot()
        data_table = gr.Dataframe(
            value=self.dataset.round(self.precision),
            visible=True,
            headers=list(self.dataset.columns),
        )
        return vis, data_table

    def train(self):
        """
        Train a model, show its decision boundary and a classification report.

        The model is fitted inside ``plot(decision_boundary=True)``. The report
        evaluates it on exactly the (normalized / jittered) features it was
        fitted on.

        Returns
        -------
        (image, HTML visibility update, report HTML). On failure the error is
        shown with ``gr.Info`` and the report is hidden.
        """
        image = self.plot(decision_boundary=True)
        try:
            if self._train_features is None:
                raise ValueError("No model could be trained on the current dataset.")
            # Raw dataset columns would not match a model trained on
            # normalized features, so reuse the features from plot().
            X, y = self._train_features
            pred = self.model.predict(X)
            df = pd.DataFrame(classification_report(y, pred, output_dict=True)).T
            summary = df.to_html(classes="report-table", float_format="%.2f")
            return image, gr.HTML(visible=True), "<b>Classification report</b><br>" + summary
        except Exception as e:
            logger.exception("Failed to build the classification report")
            gr.Info(f"⚠️ {e}")
            return image, gr.HTML(visible=False), ""

    def clear(self):
        """Clear data points (empty Draw2D-style dataset); return plot and hidden table."""
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        return self.plot(), gr.Dataframe(visible=False)

    def save(self):
        """Handler of the Save button: write the dataset to DATASET_FILE."""
        self.save_data()

    def save_data(self):
        """Write the dataset to DATASET_FILE (CSV, no index) and return the path."""
        # TODO: allow user-specified filename
        self.dataset.to_csv(self.DATASET_FILE, index=False)
        logger.info(f"{self.DATASET_FILE} updated")
        return self.DATASET_FILE

    def update_model(self, classifier_name):
        """
        Select the classifier class used for the next training run.

        Returns "" to clear the classifier-arguments textbox; the stored
        arguments are cleared too since they belonged to the previous
        classifier. Unknown names (the dropdown allows custom values) are
        reported and leave the current selection unchanged.
        """
        if classifier_name not in self.classifiers:
            gr.Info(f"⚠️ Unknown classifier: {classifier_name}")
            return gr.update()
        self.model_class = self.classifiers[classifier_name]
        self.model_args = ""
        logger.info(f'Updated model to {self.model_class}')
        return ""

    def save_model(self):
        """Pickle the last trained model to MODEL_FILE and return the path."""
        with open(self.MODEL_FILE, "wb") as f:
            pickle.dump(self.model, f)
        logger.info(f"{self.MODEL_FILE} updated")
        return self.MODEL_FILE

    def save_code(self):
        """
        Render EXPORT_CODE_TEMPLATE with the current model, embedder,
        normalizer and jitter settings, format it with black, write it to
        CODE_FILE and return the path.

        Raises ValueError if the selected model/embedder/normalizer class has
        no entry in its ``*_imports.yaml`` file.
        """
        model_class = str(self.model_class.__name__)
        model_imports = yaml.safe_load(Path("model_imports.yaml").read_text())
        if model_class not in model_imports:
            raise ValueError(f"Model {model_class} not found in model_imports.yaml")
        model_import_stmt = f"{model_imports[model_class]}"

        embedder_class = str(self.embedder_class.__name__)
        if embedder_class == "CoordinateProjection2d":
            # not importable from a library: embed the class source instead
            embedder_import_stmt = f"\n\n\n{inspect.getsource(CoordinateProjection2d)}".rstrip()
        else:
            embedder_imports = yaml.safe_load(Path("embedder_imports.yaml").read_text())
            if embedder_class not in embedder_imports:
                raise ValueError(f"Embedder {embedder_class} not found in embedder_imports.yaml")
            embedder_import_stmt = f"\n{embedder_imports[embedder_class]}"

        if self.normalizer_class is not None:
            normalizer_class = str(self.normalizer_class.__name__)
            normalizer_imports = yaml.safe_load(Path("normalizer_imports.yaml").read_text())
            if normalizer_class not in normalizer_imports:
                raise ValueError(f"Normalizer {normalizer_class} not found in normalizer_imports.yaml")
            normalizer_import_stmt = f"\n{normalizer_imports[normalizer_class]}"
        else:
            normalizer_import_stmt = ""

        # plot range: the fixed unit canvas for Draw2D, otherwise expressions
        # evaluated in the generated code
        if self.dataset_type == 'Draw2D':
            x_min = 0
            x_max = 1
            y_min = 0
            y_max = 1
        else:
            x_min = "X_embedded[:, 0].min()"
            x_max = "X_embedded[:, 0].max()"
            y_min = "X_embedded[:, 1].min()"
            y_max = "X_embedded[:, 1].max()"

        model_params = parse_param_string(self.model_args)
        if len(model_params) == 0:
            model_params_text = ""
        else:
            model_params_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in model_params.items()]) + "\n\t"

        if embedder_class == "CoordinateProjection2d":
            embedder_args = {**parse_param_string(self.embedder_args)}
        else:
            embedder_args = {"n_components": 2, **parse_param_string(self.embedder_args)}
        embedder_args_text = "".join([f"\n\t\t{k}={repr(v)}," for k, v in embedder_args.items()]) + "\n\t"

        template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text())
        variables = {
            'model_import_statement': model_import_stmt,
            'embedder_import_statement': embedder_import_stmt,
            "normalizer_import_statement": normalizer_import_stmt,
            'dataset_file': self.DATASET_FILE,
            'embedder_class': embedder_class,
            'embedder_args': embedder_args_text,
            'model_class': model_class,
            'model_params': model_params_text,
            'fig_width': self.canvas_width / 100,
            'fig_height': self.canvas_height / 100,
            'dpi': 100,
            'num_dots': self.num_dots,
            'x_min': x_min,
            'x_max': x_max,
            'y_min': y_min,
            'y_max': y_max,
            "normalize": self.normalizer_class is not None,
            "normalizer_class": self.normalizer_class.__name__ if self.normalizer_class is not None else "",
            "jitter": self.jitter_std > 0,
            "jitter_scale": self.jitter_std,
        }
        rendered_code = template.render(variables)
        rendered_code = black.format_str(rendered_code, mode=black.FileMode())
        Path(self.CODE_FILE).write_text(rendered_code)
        logger.info(f"{self.CODE_FILE} updated")
        return self.CODE_FILE

    def update_figure_extension(self, ext):
        """Set the file extension used by save_figure."""
        self.figure_extension = ext
        logger.info(f'updated figure extension: {self.figure_extension}')

    def save_figure(self):
        """Re-plot with the decision boundary, save the figure and return its path."""
        self.plot(decision_boundary=True, save_figure=True)
        return f"{self.FIGURE_BASENAME}{self.figure_extension}"

    def update_args(self, model_args):
        """Store the classifier argument string used for the next fit."""
        self.model_args = model_args
        logger.info(f'updated model_args: {self.model_args}')

    def update_embedder(self, embedder):
        """Select the 2-D embedder by name and re-plot.

        Unknown names (the dropdown allows custom values) are reported and
        leave the current embedder unchanged.
        """
        if embedder not in self.embedders:
            gr.Info(f"⚠️ Unknown embedder: {embedder}")
            return self.plot()
        self.embedder_class = self.embedders[embedder]
        logger.info(f'updated Embedder: {self.embedder_class}')
        return self.plot()

    def update_embedder_args(self, embedder_args):
        """Store the embedder argument string and re-plot."""
        self.embedder_args = embedder_args
        logger.info(f'updated Embedder args: {self.embedder_args}')
        return self.plot()

    def update_normalizer(self, normalizer):
        """Select the normalizer by name; return the plot and the first 100 rows."""
        self.normalizer_class = self.normalizers[normalizer]
        logger.info(f'updated Normalizer: {self.normalizer_class}')
        data_table = gr.Dataframe(
            value=self.dataset[:100].round(self.precision),
            visible=True,
            headers=list(self.dataset.columns),
        )
        return self.plot(), data_table

    def update_jittering(self, jitter_std):
        """Set the jitter standard deviation (0 if not parseable) and re-plot."""
        try:
            self.jitter_std = float(jitter_std)
        except (ValueError, TypeError):
            self.jitter_std = 0
        logger.info(f'updated Jittering std: {self.jitter_std}')
        return self.plot()

    def handle_dataset_radio(self, type):
        """
        Switch between the 'Draw2D', 'Upload' and 'sklearn' data sources.

        Resets preprocessing/embedding settings and clears the dataset.
        Returns the plot, the data table and one update per component wired
        in launch(), in this order: file_chooser, sklearn_data_selector,
        normalizer_selector, jittering_textbox, embedder_selector,
        embedder_args_textbox, label, btn_clear, btn_save.
        """
        if type == 'Draw2D':
            self.custom_selected = True
            self.dataset_type = "Draw2D"
            new_fields = (
                gr.File(visible=False),
                gr.Dropdown(visible=False, value="None"),
                gr.Dropdown(visible=False, value="None"),
                gr.Textbox(visible=False, value="0"),
                gr.Dropdown(visible=False, value="CoordinateProjection2d"),
                gr.Textbox(visible=False),
                gr.Radio(visible=True),
                gr.Button(visible=True),
                gr.Button(visible=False),
            )
        elif type == 'Upload':
            self.dataset_type = "Upload"
            self.custom_selected = False
            new_fields = (
                gr.File(visible=True),
                gr.Dropdown(visible=False, value="None"),
                gr.Dropdown(visible=True, value="None"),
                gr.Textbox(visible=True, value="0"),
                gr.Dropdown(visible=True, value="CoordinateProjection2d"),
                gr.Textbox(visible=True),
                gr.Radio(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
            )
        elif type == 'sklearn':
            self.dataset_type = "sklearn"
            self.custom_selected = False
            new_fields = (
                gr.File(visible=False),
                gr.Dropdown(visible=True, value="None"),
                gr.Dropdown(visible=True, value="None"),
                gr.Textbox(visible=True, value="0"),
                gr.Dropdown(visible=True, value="CoordinateProjection2d"),
                gr.Textbox(visible=True),
                gr.Radio(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
            )
        else:
            # previously fell through to a NameError on `new_fields`
            logger.error(f'Error - unknown dataset type: {type}')
            raise gr.Error(f"Unknown dataset type: {type}")
        self._reset_data_processing_and_embedding()
        plot, data_table = self.clear()
        return plot, data_table, *new_fields

    def load_local_data_and_plot(self, filename):
        """Load an uploaded dataset (if a file is given) and return plot and table.

        The 'target' column is converted to strings.
        """
        if filename is not None:
            self.dataset = read(filename)
            self.dataset['target'] = self.dataset['target'].astype(str)
            self.dataset_type = 'Upload'
            logger.info(f'Loaded dataset from {filename}')
        vis = self.plot()
        # TODO: need to make it explicit that this only shows first 100 points
        data_table = gr.Dataframe(
            value=self.dataset[:100].round(self.precision),
            visible=True,
            headers=list(self.dataset.columns)
        )
        return vis, data_table

    def load_sklearn_data_and_plot(self, datasetname):
        """Load a scikit-learn dataset by name (ignored for None/'None') and
        return plot and table.

        Missing feature names default to 'F0', 'F1', ...; integer targets are
        mapped to ``target_names`` or to 'C0', 'C1', ...
        """
        if datasetname is not None and datasetname != "None":
            dataset = self.dataloaders[datasetname]()
            X = dataset.data
            y = dataset.target
            if hasattr(dataset, 'feature_names'):
                feature_names = dataset.feature_names
            else:
                feature_names = [f'F{i}' for i in range(np.shape(X)[1])]
            if hasattr(dataset, 'target_names'):
                labels = dataset.target_names
            else:
                labels = [f'C{i}' for i in range(len(np.unique(y)))]
            y = np.array([labels[i] for i in y])
            self.dataset = pd.DataFrame(X, columns=feature_names)
            self.dataset['target'] = y.astype(str)
            self.dataset_type = 'sklearn'
            logger.info(f'Loaded dataset {datasetname}')
        vis = self.plot()
        # TODO: need to make it explicit that this only shows first 100 points
        data_table = gr.Dataframe(
            value=self.dataset[:100].round(self.precision), visible=True,
            headers=list(self.dataset.columns)
        )
        return vis, data_table

    def launch(self):
        """Build the Gradio interface, wire all event handlers and start the server."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Interactive Decision Boundary Visualizer</div>")

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    self.data_image = gr.Image(
                        value=self.plot(),
                        container=True,
                        show_share_button=False,
                        show_fullscreen_button=False,
                        show_download_button=False,
                        show_label=False,
                    )
                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_radio = gr.Radio(
                            ["Draw2D", "Upload", "sklearn"],
                            value="Draw2D",
                            label="Dataset type",
                            elem_id="rowheight",
                        )
                        # upload data
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_classes="file-chooser")
                        self.file_chooser = file_chooser
                        # sklearn data dropdown menu
                        sklearn_data_selector = gr.Dropdown(
                            choices=self.dataloaders,
                            label='Select dataset',
                            value='None',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.sklearn_data_selector = sklearn_data_selector
                        # normalization
                        normalizer_selector = gr.Dropdown(
                            choices=self.normalizers,
                            label='Select normalizer',
                            value='None',
                            visible=False,
                        )
                        self.normalizer_selector = normalizer_selector
                        # jittering
                        jittering_textbox = gr.Textbox(label="Set jittering std", value="0", visible=False)
                        self.jittering_textbox = jittering_textbox
                        # embedder
                        embedder_selector = gr.Dropdown(
                            choices=self.embedders,
                            label='Select embedder (only for plotting)',
                            value='CoordinateProjection2d',
                            visible=False,
                            allow_custom_value=True,
                        )
                        self.embedder_selector = embedder_selector
                        embedder_args_textbox = gr.Textbox(label="Embedder arguments", visible=False)
                        self.embedder_args_textbox = embedder_args_textbox
                        # custom data
                        label = gr.Radio(["Gray", "Orange", "Blue"], value="Gray", label="Choose point label", visible=True, elem_id="rowheight")
                        self.label = label
                        with gr.Row():
                            btn_clear = gr.Button("Clear", visible=True, elem_id="my-button")
                            self.btn_clear = btn_clear
                            btn_save = gr.Button("Save", visible=False, elem_id="my-button")
                            self.btn_save = btn_save
                        data_table = gr.Dataframe(visible=False)

                    # classifier selector
                    with gr.Tab("Classifier"):
                        # specify model
                        model_selector = gr.Dropdown(choices=self.classifiers,
                                                     label='Select Classifier',
                                                     value='LinearSVC',
                                                     allow_custom_value=True)
                        self.model_selector = model_selector
                        # specify arguments
                        args_textbox = gr.Textbox(label="Classifier arguments")
                        self.args_textbox = args_textbox
                        model_selector.change(fn=self.update_model, inputs=model_selector, outputs=args_textbox)
                        btn_train = gr.Button("Train Model")
                        classification_summary = gr.HTML(visible=False)

                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        btn_export_data = gr.Button("Data")
                        btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
                        btn_export_model = gr.Button('Model')
                        btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
                        btn_export_code = gr.Button('Code')
                        btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
                        with gr.Row():
                            btn_export_figure = gr.Button('Figure')
                            btn_export_figure_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_figure_hidden", elem_classes="hidden-button")
                            figure_extension_selector = gr.Dropdown(choices=['.svg', '.pdf', '.png', '.jpeg'], label="File extension", value=".svg")

                    with gr.Tab("Options"):
                        grid_resolution_slider = gr.Slider(minimum=100, maximum=1000, value=200, step=10, label="Decision boundary grid resolution")
                        cmap_textbox = gr.Textbox(label="Colormap")
                        precision_slider = gr.Slider(minimum=0, maximum=20, value=2, step=1, label="# decimal place in datatable")
                        marker_size_slider = gr.Slider(minimum=0, maximum=200, value=100, step=5, label="Marker size")

                    with gr.Tab("Usage"):
                        # read_text opens and closes the file
                        gr.Markdown(Path('usage.md').read_text())

            # event handlers for GUI elements
            self.data_image.select(self.add_point, inputs=label,
                                   outputs=(self.data_image, data_table))
            dataset_radio.change(
                fn=self.handle_dataset_radio,
                inputs=dataset_radio,
                outputs=(
                    self.data_image, data_table, file_chooser, sklearn_data_selector, normalizer_selector, jittering_textbox, embedder_selector, embedder_args_textbox, label, btn_clear, btn_save
                ),
            )

            # events for custom dataset
            btn_clear.click(fn=self.clear, outputs=(self.data_image, data_table))
            btn_save.click(fn=self.save)

            # events for local dataset
            file_chooser.change(fn=self.load_local_data_and_plot,
                                inputs=file_chooser,
                                outputs=(self.data_image, data_table))

            # events for sklearn dataset
            sklearn_data_selector.change(fn=self.load_sklearn_data_and_plot,
                                         inputs=sklearn_data_selector,
                                         outputs=(self.data_image, data_table))
            embedder_selector.change(fn=self.update_embedder,
                                     inputs=embedder_selector,
                                     outputs=self.data_image)
            embedder_args_textbox.change(
                fn=self.update_embedder_args,
                inputs=embedder_args_textbox,
                outputs=self.data_image,
            )
            normalizer_selector.change(
                fn=self.update_normalizer,
                inputs=normalizer_selector,
                outputs=(self.data_image, data_table),
            )
            jittering_textbox.change(
                fn=self.update_jittering,
                inputs=jittering_textbox,
                outputs=self.data_image,
            )

            # Chain the handlers: two independent .click() listeners may run in
            # any order, so train() could otherwise fit with stale arguments.
            btn_train.click(fn=self.update_args, inputs=args_textbox).then(
                fn=self.train,
                outputs=(self.data_image, classification_summary, classification_summary),
            )

            # events for export
            # create files on the fly using hidden download buttons
            # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
            btn_export_data.click(
                fn=self.save_data,
                inputs=None,
                outputs=[btn_export_data_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_data_hidden').click()"
            )
            btn_export_model.click(
                fn=self.save_model,
                inputs=None,
                outputs=[btn_export_model_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_model_hidden').click()"
            )
            btn_export_code.click(
                fn=self.save_code,
                inputs=None,
                outputs=[btn_export_code_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_code_hidden').click()"
            )
            btn_export_figure.click(
                fn=self.save_figure,
                inputs=None,
                outputs=[btn_export_figure_hidden]
            ).then(
                fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_figure_hidden').click()"
            )
            figure_extension_selector.change(self.update_figure_extension, inputs=figure_extension_selector)

            # events for options
            grid_resolution_slider.change(self.update_resolution, inputs=grid_resolution_slider, outputs=self.data_image)
            cmap_textbox.submit(self.update_cmap, inputs=cmap_textbox, outputs=self.data_image)
            precision_slider.change(self.update_precision, inputs=precision_slider, outputs=data_table)
            marker_size_slider.change(self.update_marker_size, inputs=marker_size_slider, outputs=self.data_image)

        demo.launch()
# Script entry: build the app with a 1200 x 900 px drawing canvas and start
# the Gradio server (launch() does not return until the server stops).
visualizer = InteractiveDecisionBoundary(height=900, width=1200)
visualizer.launch()