Spaces:
Running
Running
initial commit
Browse files- README.md +5 -5
- decision_boundary.py +452 -0
- requirements.txt +6 -0
- util.py +92 -0
README.md
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.1
|
| 8 |
-
app_file:
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: ELVIS Decision Boundary
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.44.1
|
| 8 |
+
app_file: decision_boundary.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
decision_boundary.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import inspect
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import io
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import sklearn
|
| 9 |
+
from sklearn.linear_model import LogisticRegression
|
| 10 |
+
from sklearn.svm import LinearSVC
|
| 11 |
+
from sklearn.base import ClassifierMixin
|
| 12 |
+
from sklearn.datasets import load_iris
|
| 13 |
+
from sklearn.decomposition import PCA
|
| 14 |
+
from sklearn.metrics import classification_report
|
| 15 |
+
|
| 16 |
+
from collections import deque
|
| 17 |
+
|
| 18 |
+
from util import *
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
# Root logging is set up a single time, at import, so every later log call
# shares the same threshold and line format.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",  # timestamp [LEVEL] message
    level=logging.INFO,  # lowest level that is emitted
)
logger = logging.getLogger("ELVIS")
|
| 27 |
+
|
| 28 |
+
# TODO:
|
| 29 |
+
# - support for session: load a previous session and continue from there
|
| 30 |
+
|
| 31 |
+
def label2color(labels):
    '''
    Assign a matplotlib colour name to every label.

    A label whose lower-cased text is itself a palette colour (e.g. "Red")
    gets that colour; every other label gets the next unused palette colour.

    labels: iterable of labels; non-string labels are converted with str().
    Returns a list of colour names, one per label, in input order.
    '''
    palette = ["red", "green", "blue", "yellow", "orange", "purple", "pink", "brown", "gray", "black"]
    names = [str(label).lower() for label in labels]

    # Reserve colours that are spelled out by a label *before* handing out
    # the rest, so an earlier label cannot take "red" away from "Red".
    reserved = {name for name in names if name in palette}
    free = deque(color for color in palette if color not in reserved)

    colors = []
    claimed = set()
    overflow = 0
    for name in names:
        if name in reserved and name not in claimed:
            claimed.add(name)
            colors.append(name)
        elif free:
            colors.append(free.popleft())
        else:
            # More labels than palette entries: reuse colours cyclically
            # instead of failing on an empty deque.
            colors.append(palette[overflow % len(palette)])
            overflow += 1

    return colors
|
| 44 |
+
|
| 45 |
+
def toydata():
    '''
    Return a six-point, two-class example dataset.

    The frame has the columns 'label', 'F1' and 'F2': three 'Red' points
    followed by three 'Blue' points.
    '''
    columns = {
        'label': ['Red', 'Red', 'Red', 'Blue', 'Blue', 'Blue'],
        'F1': [0.12375, 0.19, 0.27375, 0.50625, 0.38375, 0.28875],
        'F2': [0.8516666666666667, 0.8916666666666666, 0.9233333333333333,
               0.785, 0.6733333333333333, 0.595],
    }
    return pd.DataFrame(columns, columns=['label', 'F1', 'F2'])
|
| 54 |
+
|
| 55 |
+
class CoordinateProjection:
    '''
    Trivial "embedder" that keeps a fixed subset of the input columns.

    It mimics the fit_transform / transform / inverse_transform methods of
    the scikit-learn embedders it is used alongside. The inverse fills the
    dropped columns with their mean over the fitted data.

    TODO: allow user to specify different coordinates
    '''

    def __init__(self, n_components=None, dims=None):
        '''
        n_components: accepted so the class can be constructed like the other
            embedders; it is not used.
        dims: column indices to keep; defaults to the first two columns.
        '''
        # Build the default per instance; a list default in the signature
        # would be shared by every instance.
        self.dims = [0, 1] if dims is None else list(dims)

    def fit_transform(self, X):
        '''Remember the per-column mean of X and return the selected columns.'''
        X = np.asarray(X)
        # Compute the mean in float so that object-dtype input (e.g. values
        # taken from an object-dtype DataFrame) yields a numeric mean.
        self.mean = X.astype(float).mean(axis=0)
        return X[:, self.dims]

    def transform(self, X):
        '''Return the selected columns of X.'''
        return np.asarray(X)[:, self.dims]

    def inverse_transform(self, Z):
        '''
        Map projected points back to the full feature space.

        Every row starts as the fitted column means; the selected columns are
        then overwritten with Z. Requires a prior fit_transform call.
        '''
        Z = np.asarray(Z)
        X = np.tile(self.mean, (len(Z), 1))
        X[:, self.dims] = Z
        return X
|
| 75 |
+
|
| 76 |
+
class InteractiveDecisionBoundary:
    '''
    Gradio app for drawing or loading a labelled dataset, training a
    scikit-learn classifier on it and showing the decision boundary.

    The dataset is a DataFrame with a 'target' column (class label) and one
    column per feature. Data with more than two features is plotted through
    the selected embedder.
    '''

    def __init__(self, width, height):
        '''
        width, height: figure size in pixels. add_point also divides click
            positions by these values to obtain [0, 1] coordinates.
        '''
        self.canvas_width = width
        self.canvas_height = height

        self.classifiers = get_sklearn_classifiers()
        self.dataloaders = get_sklearn_dataloaders()

        # Import the submodules explicitly: a bare `import sklearn` does not
        # guarantee that they are reachable as attributes of the package.
        from sklearn import decomposition, random_projection

        self.embedders = {
            'CoordinateProjection': CoordinateProjection,
            'GaussianRandomProjection': random_projection.GaussianRandomProjection,
            'SparseRandomProjection': random_projection.SparseRandomProjection,
        }
        for cls_name, cls in inspect.getmembers(decomposition, inspect.isclass):
            self.embedders[cls_name] = cls

        self.Embedder = CoordinateProjection

        # default classifier model
        self.Model = LinearSVC
        # Classifier arguments as typed by the user ("C=1.0, tol=1e-3");
        # parsed with parse_param_string right before the model is built.
        self.model_args = ""

        # Set by plot(): the last fitted embedder / classifier.
        self.embedder = None
        self.model = None

        # todo: support arbitrary number of classes and user-defined class labels
        self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
        self.dataset_type = 'Draw2D'
        self.custom_selected = True

        self.data_image = None
        self.boundary_image = None

        self.css = """
        #my-button {
            height: 30px;
            font-size: 16px;
        }

        #rowheight {
            height: 90px;
        }

        .report-table {
            border: 0 !important;
        }
        .report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
            border: 0 !important;
            padding: 6px 12px;
            text-align: center;
        }"""

    def _features_and_targets(self):
        '''Split the dataset into a float feature matrix X and a label array y.'''
        X = self.dataset.loc[:, self.dataset.columns != 'target'].to_numpy(dtype=float)
        y = self.dataset['target'].to_numpy()
        return X, y

    def _table_update(self):
        '''Update for the data table showing at most the first 100 rows.'''
        return gr.Dataframe(value=self.dataset[:100], visible=True,
                            headers=list(self.dataset.columns))

    def _draw_dataset(self, ax, decision_boundary):
        '''Scatter the embedded dataset on ax; optionally fit the model and overlay its predictions.'''
        X, y = self._features_and_targets()
        logger.info("Data:\n%s", X)
        logger.info("Target:\n%s", y)

        # compute embedding
        embedder = self.Embedder(n_components=2)
        self.embedder = embedder
        logger.info("Embedder = %s", embedder)

        Z = embedder.fit_transform(X)
        logger.info("Projected data:\n%s", Z)

        labels = np.unique(y)
        colors = label2color(labels)
        logger.info("Classes:\n%s", labels)
        logger.info("Colors:\n%s", colors)
        l2c = dict(zip(labels, colors))

        # scatter plots for data
        for label, color in zip(labels, colors):
            subset = Z[y == label]
            ax.scatter(subset[:, 0], subset[:, 1], color=color, label=label)
        ax.legend()

        if not decision_boundary:
            return

        # The model is trained in the original feature space ...
        model = self.Model(**parse_param_string(self.model_args))
        model.fit(X, y)
        self.model = model

        # ... and evaluated on a 100x100 grid laid out in the projected space.
        if self.dataset_type == 'Draw2D':
            xx, yy = np.meshgrid(np.linspace(0, 1, 100),
                                 np.linspace(0, 1, 100))
        else:
            xx, yy = np.meshgrid(np.linspace(Z[:, 0].min(), Z[:, 0].max(), 100),
                                 np.linspace(Z[:, 1].min(), Z[:, 1].max(), 100))

        grid = np.c_[xx.ravel(), yy.ravel()]
        # Grid points are mapped back to feature space before prediction.
        preds = model.predict(embedder.inverse_transform(grid)).reshape(xx.shape)
        ax.scatter(xx.ravel(), yy.ravel(), c=[l2c[l] for l in preds.ravel()], s=1, alpha=0.5)

    def plot(self, decision_boundary=False):
        '''
        Plot data and decision boundary with matplotlib and return as PIL image.

        decision_boundary: when True, fit self.Model on the dataset (stored as
            self.model) and colour a grid of points by predicted class.
        Raises gr.Error if embedding, fitting or prediction fails.
        '''
        logger.info("Initializing figure")
        fig = plt.figure(figsize=(self.canvas_width / 100.0, self.canvas_height / 100.0), dpi=100)
        try:
            # set entire figure to be the canvas to allow simple conversion of mouse
            # position to coordinates in the figure
            ax = fig.add_axes([0., 0., 1., 1.])
            ax.margins(x=0, y=0)  # no padding in both directions

            if self.dataset_type == 'Draw2D':
                # draw canvas boundary
                ax.plot([0, 0, 1, 1, 0], [0, 1, 1, 0, 0], color='brown')
                ax.axis('off')

            # TODO: allow showing x and y axes with ticks and labels
            if self.dataset is not None and len(self.dataset) > 0:
                try:
                    self._draw_dataset(ax, decision_boundary)
                except Exception as e:
                    raise gr.Error(f"⚠️ {e}") from e

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            # Close the figure on the error path too, so a failed plot does
            # not leave a pyplot figure registered.
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        # TODO: add a save function for saving screenshot
        return img

    def add_point(self, evt: gr.SelectData, label):
        '''
        Mouse click to add a point.

        evt: Gradio select event; evt.index holds the clicked (x, y) pixel.
        label: class label chosen in the UI for the new point.
        Returns the re-rendered image and a data-table update.
        '''
        if self.custom_selected:
            if self.dataset_type != 'Draw2D':
                # Switching back from a loaded dataset: start a fresh drawing.
                self.dataset = pd.DataFrame(columns=['target', 'F1', 'F2'])
                self.dataset_type = 'Draw2D'

            # normalize clicked position to [0, 1]
            x = evt.index[0] / self.canvas_width
            y = 1 - evt.index[1] / self.canvas_height  # flip y-axis to match matplotlib

            self.dataset.loc[len(self.dataset)] = [label, x, y]

            logger.info(f'clicked ({evt.index[0]}, {evt.index[1]}), mapped to ({x}, {y})')

        return self.plot(), self._table_update()

    # train a model and show decision boundary
    def train(self, model_args=None):
        '''
        Fit the selected classifier and report its accuracy on the training data.

        model_args: optional argument string from the UI. Passing it here
            (instead of through a separate event handler) guarantees the
            arguments are stored before the model is built.
        Returns (image, HTML visibility update, HTML classification report).
        Raises gr.Error when there is no data to train on.
        '''
        if model_args is not None:
            self.update_args(model_args)

        if self.dataset is None or len(self.dataset) == 0:
            raise gr.Error("⚠️ Add or load some data before training a model")

        image = self.plot(decision_boundary=True)
        X, y = self._features_and_targets()
        pred = self.model.predict(X)
        df = pd.DataFrame(classification_report(y, pred, output_dict=True)).T
        summary = df.to_html(classes="report-table", float_format="%.2f")

        return image, gr.HTML(visible=True), "<b>Classification report</b><br>" + summary

    # clear dataset and replot
    def clear(self):
        '''Drop all rows (columns are kept) and return the re-rendered image.'''
        self.dataset = self.dataset.iloc[0:0].copy()
        return self.plot()

    # save dataset
    def save(self):
        '''Write the dataset to dataset.csv in the working directory.'''
        # TODO: allow user-specified filename
        self.dataset.to_csv('dataset.csv', index=False)

    def update_model(self, classifier_name):
        '''
        Select the classifier class to train and reset its arguments.

        Returns "" to clear the arguments textbox.
        Raises gr.Error for a name that is not a known classifier.
        '''
        if classifier_name not in self.classifiers:
            raise gr.Error(f"⚠️ Unknown classifier: {classifier_name}")

        self.Model = self.classifiers[classifier_name]
        # Arguments typed for the previous classifier do not carry over.
        self.model_args = ""
        self.args_textbox.value = ""
        logger.info('Updated model to %s', self.Model)

        return ""

    def update_args(self, model_args):
        '''Store the classifier argument string entered in the UI.'''
        self.model_args = model_args
        logger.info('updated model_args: %s', self.model_args)

    def update_embedder(self, embedder):
        '''
        Select the embedder class by name and return the re-rendered image.

        Raises gr.Error for a name that is not a known embedder.
        '''
        if embedder not in self.embedders:
            raise gr.Error(f"⚠️ Unknown embedder: {embedder}")

        self.Embedder = self.embedders[embedder]
        logger.info('updated embedder: %s', self.Embedder)
        return self.plot()

    def handle_dataset_radio(self, type):
        '''
        Show the controls that belong to the selected dataset type.

        type: 'Draw2D', 'Upload' or 'sklearn'.
        Returns visibility updates for, in order: file chooser, sklearn
        dataset selector, embedder selector, point-label radio, Clear button
        and Save button.
        Raises gr.Error for any other type.
        '''
        visibility = {
            'Draw2D': (False, False, False, True, True, True),
            'Upload': (True, False, True, False, False, False),
            'sklearn': (False, True, True, False, False, False),
        }
        if type not in visibility:
            raise gr.Error(f"⚠️ Unknown dataset type: {type}")

        self.custom_selected = (type == 'Draw2D')
        return tuple(gr.update(visible=shown) for shown in visibility[type])

    def load_local_data_and_plot(self, filename):
        '''
        Load an uploaded file as the dataset and plot it.

        filename: path of the uploaded file, or None (only re-plots).
        Returns the rendered image and a data-table update.
        Raises gr.Error if the file cannot be read or has no 'target' column.
        '''
        if filename is not None:
            try:
                dataset = read(filename)
            except Exception as e:
                raise gr.Error(f"⚠️ {e}") from e

            if 'target' not in dataset.columns:
                raise gr.Error("⚠️ The uploaded file must contain a 'target' column")

            dataset['target'] = dataset['target'].astype(str)
            self.dataset = dataset
            self.dataset_type = 'Upload'
            logger.info(f'Loaded dataset from {filename}')

        return self.plot(), self._table_update()

    def load_sklearn_data_and_plot(self, datasetname):
        '''
        Load one of the registered scikit-learn datasets and plot it.

        datasetname: key of self.dataloaders, or None (only re-plots).
        Returns the rendered image and a data-table update.
        Raises gr.Error for an unknown name or a dataset that is not a dense
        2-D feature table.
        '''
        if datasetname is not None:
            if datasetname not in self.dataloaders:
                raise gr.Error(f"⚠️ Unknown dataset: {datasetname}")

            try:
                dataset = self.dataloaders[datasetname]()
                X = dataset.data
                y = dataset.target

                if hasattr(X, 'toarray') or getattr(X, 'ndim', 0) != 2:
                    raise ValueError(f"dataset '{datasetname}' is not a dense 2-D feature table")

                if hasattr(dataset, 'feature_names'):
                    feature_names = dataset.feature_names
                else:
                    feature_names = [f'F{i}' for i in range(X.shape[1])]

                frame = pd.DataFrame(X, columns=feature_names)
                frame['target'] = np.asarray(y).astype(str)
            except Exception as e:
                raise gr.Error(f"⚠️ {e}") from e

            self.dataset = frame
            self.dataset_type = 'sklearn'
            logger.info(f'Loaded dataset {datasetname}')

        return self.plot(), self._table_update()

    def launch(self):
        '''Build the Gradio interface, wire the event handlers and start the app.'''
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.Markdown("<div style='text-align:left; font-size:40px; font-weight: bold;'>ELVIS Interactive Decision Boundary Visualizer</div>")

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    self.data_image = gr.Image(value=self.plot(), container=True)

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_radio = gr.Radio(["Draw2D", "Upload", "sklearn"],
                                                 value="Draw2D", label="Dataset type", elem_id="rowheight")

                        # upload data
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
                        self.file_chooser = file_chooser

                        # sklearn data dropdown menu
                        sklearn_data_selector = gr.Dropdown(choices=list(self.dataloaders),
                                                            label='Select dataset',
                                                            value='None',
                                                            visible=False,
                                                            allow_custom_value=True)
                        self.sklearn_data_selector = sklearn_data_selector

                        # embedder
                        embedder_selector = gr.Dropdown(choices=list(self.embedders),
                                                        label='Select embedder',
                                                        value='CoordinateProjection',
                                                        visible=False,
                                                        allow_custom_value=True)

                        # custom data
                        label = gr.Radio(["Red", "Green", "Blue"], value="Red", label="Choose point label", visible=True, elem_id="rowheight")
                        self.label = label

                        with gr.Row():
                            btn_clear = gr.Button("Clear", visible=True, elem_id="my-button")
                            self.btn_clear = btn_clear

                            btn_save = gr.Button("Save", visible=True, elem_id="my-button")
                            self.btn_save = btn_save

                        data_table = gr.Dataframe(visible=False)

                    # classifier selector
                    with gr.Tab("Classifier"):
                        # specify model
                        model_selector = gr.Dropdown(choices=list(self.classifiers),
                                                     label='Select Classifier',
                                                     value='LinearSVC',
                                                     allow_custom_value=True)
                        self.model_selector = model_selector

                        # specify arguments
                        args_textbox = gr.Textbox(label="Classifier arguments")
                        self.args_textbox = args_textbox

                        model_selector.change(fn=self.update_model, inputs=model_selector, outputs=args_textbox)

                        btn_train = gr.Button("Train Model")

                        classification_summary = gr.HTML(visible=False)

                    with gr.Tab("Export"):
                        btn_export_data = gr.Button('Data')

                        btn_export_model = gr.Button('Model')

                        btn_export_code = gr.Button('Code')

                    with gr.Tab("Options"):
                        pass

                    with gr.Tab("Usage"):
                        pass

            # event handlers for GUI elements
            self.data_image.select(self.add_point, inputs=label,
                                   outputs=(self.data_image, data_table))

            dataset_radio.change(fn=self.handle_dataset_radio,
                                 inputs=dataset_radio,
                                 outputs=(file_chooser, sklearn_data_selector, embedder_selector, label, btn_clear, btn_save))

            # events for custom dataset
            btn_clear.click(fn=self.clear, outputs=self.data_image)
            btn_save.click(fn=self.save)

            # events for local dataset
            file_chooser.change(fn=self.load_local_data_and_plot,
                                inputs=file_chooser,
                                outputs=(self.data_image, data_table))

            # events for sklearn dataset
            sklearn_data_selector.change(fn=self.load_sklearn_data_and_plot,
                                         inputs=sklearn_data_selector,
                                         outputs=(self.data_image, data_table))

            embedder_selector.change(fn=self.update_embedder,
                                     inputs=embedder_selector,
                                     outputs=self.data_image)

            # One handler receives the arguments and trains; two independent
            # click handlers gave no guarantee that the arguments were stored
            # before training started.
            btn_train.click(fn=self.train, inputs=args_textbox,
                            outputs=(self.data_image, classification_summary, classification_summary))

        demo.launch()
|
| 450 |
+
|
| 451 |
+
# Build the app with a 1200x900 pixel canvas and start serving it.
visualizer = InteractiveDecisionBoundary(height=900, width=1200)
visualizer.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
numpy
|
| 3 |
+
pandas
|
| 4 |
+
scikit-learn
|
| 5 |
+
mpu
|
| 6 |
+
pillow
|
util.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import sklearn
|
| 3 |
+
from sklearn.linear_model import LogisticRegression
|
| 4 |
+
from sklearn.base import ClassifierMixin
|
| 5 |
+
from sklearn.datasets import *
|
| 6 |
+
import pkgutil
|
| 7 |
+
import importlib
|
| 8 |
+
import warnings
|
| 9 |
+
import ast
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
def safe_import_module(name):
    '''
    Import the module called *name* with warnings silenced.

    Returns the module object, or None if the import raised an exception.
    '''
    module = None
    try:
        with warnings.catch_warnings():
            # Mute anything the module emits while it is being imported.
            warnings.simplefilter("ignore")
            module = importlib.import_module(name)
    except Exception:
        module = None  # or raise/log if desired
    return module
|
| 19 |
+
|
| 20 |
+
def get_sklearn_classifiers():
    '''
    Collect the classifier classes exposed by scikit-learn's public modules.

    Walks the `sklearn` package, skipping private modules (a path component
    starting with '_') and modules nested more than one level deep. From each
    importable module it keeps the classes that subclass ClassifierMixin, are
    defined inside sklearn, are not private and are not abstract.

    Returns a dict mapping class name -> class. Modules that fail to import
    are skipped.
    '''
    classifiers = {}

    for _, modname, _ in pkgutil.walk_packages(sklearn.__path__, prefix="sklearn."):
        if '._' in modname:  # exclude hidden modules
            continue

        if modname.count('.') > 1:  # exclude modules more than two levels deep
            continue

        # Best effort: a module that cannot be imported (e.g. an optional
        # dependency is missing) must not abort the discovery.
        module = safe_import_module(modname)
        if module is None:
            continue

        try:
            members = inspect.getmembers(module, inspect.isclass)
        except Exception:
            continue

        for cls_name, cls in members:
            if cls_name.startswith('_') or 'ClassifierMixin' in cls_name:
                continue
            # Abstract bases cannot be instantiated, so they are of no use as
            # a selectable model.
            if inspect.isabstract(cls):
                continue
            if issubclass(cls, ClassifierMixin) and cls.__module__.startswith("sklearn"):
                classifiers[cls_name] = cls

    return classifiers
|
| 46 |
+
|
| 47 |
+
def get_sklearn_dataloaders():
    '''Return a dict mapping a dataset name to the sklearn function that loads it.'''
    return {
        '20newsgroup': fetch_20newsgroups,
        '20newsgroup_vectorized': fetch_20newsgroups_vectorized,
        'covtype': fetch_covtype,
        'kddcup99': fetch_kddcup99,
        'iris': load_iris,
    }
|
| 56 |
+
|
| 57 |
+
def _split_top_level(text):
    '''Yield the comma-separated pieces of *text*, ignoring commas inside brackets or quotes.'''
    openers = '([{'
    closers = ')]}'
    depth = 0
    quote = None
    escaped = False
    piece = []

    for ch in text:
        if quote is not None:
            # Inside a string literal: only look for its closing quote.
            piece.append(ch)
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == quote:
                quote = None
            continue

        if ch == ',' and depth == 0:
            yield ''.join(piece)
            piece = []
            continue

        if ch in ('"', "'"):
            quote = ch
        elif ch in openers:
            depth += 1
        elif ch in closers and depth > 0:
            depth -= 1
        piece.append(ch)

    yield ''.join(piece)


def parse_param_string(param_str):
    '''
    Parse a "key=value, key=value" string into a keyword-argument dict.

    Values are evaluated as Python literals (numbers, strings, booleans,
    None, tuples, lists, dicts, ...); anything that is not a valid literal is
    kept as a stripped string, so `kernel=rbf` works without quotes. Commas
    nested in brackets or quotes belong to the value. Items without '=' (such
    as the bare '*' of a function signature) or without a key are ignored.

    param_str: the string to parse. None or an empty value gives {}; a dict
        is returned as a shallow copy.
    Returns a dict of parsed parameters.
    '''
    if not param_str:
        return {}
    if isinstance(param_str, dict):
        return dict(param_str)

    params = {}
    for item in _split_top_level(str(param_str)):
        key, sep, raw_value = item.partition('=')
        key = key.strip()
        if not sep or not key:
            continue

        raw_value = raw_value.strip()
        try:
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            value = raw_value  # fallback: treat as string
        params[key] = value

    return params
|
| 73 |
+
|
| 74 |
+
def read(filename):
    '''
    Load a tabular file into a DataFrame, choosing the reader by extension.

    Supported extensions (matched case-insensitively): .csv, .xlsx, .xls,
    .parquet, .feather and .json.

    filename: path as a str or os.PathLike, or an object whose `name`
        attribute holds the path.
    Returns the loaded DataFrame.
    Raises ValueError for any other extension.
    '''
    import os

    if isinstance(filename, os.PathLike):
        path = os.fspath(filename)
    else:
        # Upload wrappers may expose the path as `.name`; a plain str has no
        # such attribute and is used as is.
        path = getattr(filename, 'name', filename)

    readers = (
        ((".csv",), pd.read_csv),
        ((".xlsx", ".xls"), pd.read_excel),
        ((".parquet",), pd.read_parquet),
        ((".feather",), pd.read_feather),
        ((".json",), pd.read_json),
    )

    lowered = path.lower()
    for suffixes, reader in readers:
        if lowered.endswith(suffixes):
            return reader(path)

    raise ValueError("Unsupported file format.")
|
| 87 |
+
|
| 88 |
+
if __name__ == '__main__':
    # Smoke test: parse a signature-style argument string and show the result.
    example = "penalty='l2', *, dual=False, tol=0.0001, C=1.0, fit_intercept=True, intercept_scaling=1"
    print(parse_param_string(example))
|