joel-woodfield committed on
Commit
61730e7
·
1 Parent(s): c874a97

Implement code export for plot reproduction

Browse files

- Note: passing model arguments through to the exported code is not yet implemented; the generated script constructs the model with its default arguments.

decision_boundary.py CHANGED
@@ -1,9 +1,14 @@
 
 
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import inspect
4
  import numpy as np
5
  import pandas as pd
6
  import io
 
7
  from PIL import Image
8
  import sklearn
9
  from sklearn.linear_model import LogisticRegression
@@ -13,9 +18,7 @@ from sklearn.datasets import load_iris
13
  from sklearn.decomposition import PCA
14
  from sklearn.metrics import classification_report
15
  import traceback
16
-
17
- from collections import deque
18
- import pickle
19
 
20
  from util import *
21
 
@@ -60,14 +63,17 @@ class CoordinateProjection:
60
  def __init__(self, n_components=None, dims=[0, 1]):
61
  self.dims = dims
62
 
63
- def fit_transform(self, X):
64
- print(X)
65
  self.mean = X.mean(axis=0)
66
- return X[:, self.dims]
67
 
68
  def transform(self, X):
69
  return X[:, self.dims]
70
 
 
 
 
 
71
  def inverse_transform(self, Z):
72
  X = np.ones((len(Z), 1)) * self.mean
73
  X[:, self.dims] = Z
@@ -77,6 +83,9 @@ class InteractiveDecisionBoundary:
77
  DATASET_FILE = "dataset.csv"
78
  MODEL_FILE = "model.pkl"
79
 
 
 
 
80
  def __init__(self, width, height):
81
  # initialized in draw_plot
82
  #self.canvas_width = -1
@@ -97,13 +106,15 @@ class InteractiveDecisionBoundary:
97
  for cls_name, cls in inspect.getmembers(module, inspect.isclass):
98
  self.embedders[cls_name] = cls
99
 
100
- self.Embedder = CoordinateProjection
101
- #self.Embedder = PCA
 
102
 
103
  # default classifier model
104
- #self.model = LogisticRegression
105
- self.model = LinearSVC
106
  self.model_args = ""
 
107
 
108
  # todo: support arbitrary number of classes and user-defined class labels
109
  #self.dataset = toydata()
@@ -168,7 +179,7 @@ class InteractiveDecisionBoundary:
168
  logger.info("Target:\n" + str(y))
169
 
170
  # compute embedding
171
- embedder = self.Embedder(n_components=2)
172
  self.embedder = embedder
173
  logger.info("Embedder = " + str(self.embedder))
174
 
@@ -192,7 +203,7 @@ class InteractiveDecisionBoundary:
192
 
193
  # plot the decision boundary
194
  if decision_boundary:
195
- model = self.model(**parse_param_string(self.model_args))
196
  model.fit(X, y)
197
  self.model = model
198
 
@@ -290,9 +301,9 @@ class InteractiveDecisionBoundary:
290
  return self.DATASET_FILE
291
 
292
  def update_model(self, classifier_name):
293
- self.model = self.classifiers[classifier_name]
294
  self.args_textbox.value = ""
295
- logger.info(f'Updated model to {self.model}')
296
 
297
  return ""
298
 
@@ -302,13 +313,64 @@ class InteractiveDecisionBoundary:
302
  logger.info(f"{self.MODEL_FILE} updated")
303
  return self.MODEL_FILE
304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  def update_args(self, model_args):
306
  self.model_args = model_args
307
  print('updated model_args:', self.model_args)
308
 
309
  def update_embedder(self, embedder):
310
- self.Embedder = self.embedders[embedder]
311
- print('updated embedder:', self.Embedder)
312
  return self.plot()
313
 
314
  def handle_dataset_radio(self, type):
@@ -446,6 +508,7 @@ class InteractiveDecisionBoundary:
446
  btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
447
 
448
  btn_export_code = gr.Button('Code')
 
449
 
450
  with gr.Tab("Options"):
451
  slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
@@ -500,6 +563,13 @@ class InteractiveDecisionBoundary:
500
  ).then(
501
  fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_model_hidden').click()"
502
  )
 
 
 
 
 
 
 
503
 
504
 
505
  demo.launch()
 
1
+ from collections import deque
2
+ from pathlib import Path
3
+ import pickle
4
+
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
7
  import inspect
8
  import numpy as np
9
  import pandas as pd
10
  import io
11
+ from jinja2 import Template
12
  from PIL import Image
13
  import sklearn
14
  from sklearn.linear_model import LogisticRegression
 
18
  from sklearn.decomposition import PCA
19
  from sklearn.metrics import classification_report
20
  import traceback
21
+ import yaml
 
 
22
 
23
  from util import *
24
 
 
63
  def __init__(self, n_components=None, dims=[0, 1]):
64
  self.dims = dims
65
 
66
+ def fit(self, X):
 
67
  self.mean = X.mean(axis=0)
68
+ return self
69
 
70
  def transform(self, X):
71
  return X[:, self.dims]
72
 
73
+ def fit_transform(self, X):
74
+ self.fit(X)
75
+ return self.transform(X)
76
+
77
  def inverse_transform(self, Z):
78
  X = np.ones((len(Z), 1)) * self.mean
79
  X[:, self.dims] = Z
 
83
  DATASET_FILE = "dataset.csv"
84
  MODEL_FILE = "model.pkl"
85
 
86
+ CODE_FILE = "generated_code.py"
87
+ EXPORT_CODE_TEMPLATE = "export_code_template.py.j2"
88
+
89
  def __init__(self, width, height):
90
  # initialized in draw_plot
91
  #self.canvas_width = -1
 
106
  for cls_name, cls in inspect.getmembers(module, inspect.isclass):
107
  self.embedders[cls_name] = cls
108
 
109
+ self.embedder_class = CoordinateProjection
110
+ #self.embedder_class = PCA
111
+ self.embedder = self.embedder_class()
112
 
113
  # default classifier model
114
+ #self.model_class = LogisticRegression
115
+ self.model_class = LinearSVC
116
  self.model_args = ""
117
+ self.model = self.model_class()
118
 
119
  # todo: support arbitrary number of classes and user-defined class labels
120
  #self.dataset = toydata()
 
179
  logger.info("Target:\n" + str(y))
180
 
181
  # compute embedding
182
+ embedder = self.embedder_class(n_components=2)
183
  self.embedder = embedder
184
  logger.info("Embedder = " + str(self.embedder))
185
 
 
203
 
204
  # plot the decision boundary
205
  if decision_boundary:
206
+ model = self.model_class(**parse_param_string(self.model_args))
207
  model.fit(X, y)
208
  self.model = model
209
 
 
301
  return self.DATASET_FILE
302
 
303
  def update_model(self, classifier_name):
304
+ self.model_class = self.classifiers[classifier_name]
305
  self.args_textbox.value = ""
306
+ logger.info(f'Updated model to {self.model_class}')
307
 
308
  return ""
309
 
 
313
  logger.info(f"{self.MODEL_FILE} updated")
314
  return self.MODEL_FILE
315
 
316
def save_code(self):
    """Render a standalone script that reproduces the current plot.

    The import statements for the selected model and embedder classes are
    looked up in ``model_imports.yaml`` and ``embedder_imports.yaml``.
    ``CoordinateProjection`` is the exception: its source code is inlined
    into the script instead of being imported (presumably because it lives
    in this module rather than in an installable package).

    Returns:
        The path of the generated script (``CODE_FILE``).

    Raises:
        ValueError: if the selected model or embedder class has no entry in
            the corresponding YAML file.

    Note:
        ``self.model_args`` is not passed to the template.
    """
    model_class = self.model_class.__name__
    # ``safe_load`` returns None for an empty file; treat that as "no entries"
    # so the lookup below raises the intended ValueError.
    model_imports = yaml.safe_load(
        Path("model_imports.yaml").read_text(encoding="utf-8")
    ) or {}
    if model_class not in model_imports:
        raise ValueError(f"Model {model_class} not found in model_imports.yaml")
    model_import_stmt = model_imports[model_class]

    embedder_class = self.embedder_class.__name__
    if embedder_class == "CoordinateProjection":
        embedder_import_stmt = f"\n{inspect.getsource(CoordinateProjection)}"
    else:
        embedder_imports = yaml.safe_load(
            Path("embedder_imports.yaml").read_text(encoding="utf-8")
        ) or {}
        if embedder_class not in embedder_imports:
            raise ValueError(f"Embedder {embedder_class} not found in embedder_imports.yaml")
        embedder_import_stmt = embedder_imports[embedder_class]

    # Plotting range of the decision-region grid in the embedded space.
    if self.dataset_type == 'Draw2D':
        x_min, x_max = 0, 1
        y_min, y_max = 0, 1
    else:
        X = self.dataset.loc[:, self.dataset.columns != 'target'].values
        # Fit a throwaway embedder configured like the one the generated
        # script creates (n_components=2). Refitting ``self.embedder`` here
        # would silently change the state the interactive plot relies on.
        Z = self.embedder_class(n_components=2).fit_transform(X)
        # Plain Python floats, so the rendered literals do not depend on how
        # numpy scalars are formatted.
        x_min, x_max = float(Z[:, 0].min()), float(Z[:, 0].max())
        y_min, y_max = float(Z[:, 1].min()), float(Z[:, 1].max())

    template = Template(Path(self.EXPORT_CODE_TEMPLATE).read_text(encoding="utf-8"))
    variables = {
        'model_import_statement': model_import_stmt,
        'embedder_import_statement': embedder_import_stmt,
        'dataset_file': self.DATASET_FILE,
        'embedder_class': embedder_class,
        'model_class': model_class,
        # canvas size is in pixels; the figure size is in inches at 100 dpi
        'fig_width': self.canvas_width / 100,
        'fig_height': self.canvas_height / 100,
        'dpi': 100,
        'num_dots': self.num_dots,
        'x_min': x_min,
        'x_max': x_max,
        'y_min': y_min,
        'y_max': y_max,
    }

    rendered_code = template.render(variables)
    Path(self.CODE_FILE).write_text(rendered_code, encoding="utf-8")
    logger.info(f"{self.CODE_FILE} updated")
    return self.CODE_FILE
366
+
367
  def update_args(self, model_args):
368
  self.model_args = model_args
369
  print('updated model_args:', self.model_args)
370
 
371
  def update_embedder(self, embedder):
372
+ self.embedder_class = self.embedders[embedder]
373
+ print('updated Embedder:', self.embedder_class)
374
  return self.plot()
375
 
376
  def handle_dataset_radio(self, type):
 
508
  btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
509
 
510
  btn_export_code = gr.Button('Code')
511
+ btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
512
 
513
  with gr.Tab("Options"):
514
  slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
 
563
  ).then(
564
  fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_model_hidden').click()"
565
  )
566
+ btn_export_code.click(
567
+ fn=self.save_code,
568
+ inputs=None,
569
+ outputs=[btn_export_code_hidden]
570
+ ).then(
571
+ fn=None, inputs=None, outputs=None, js="() => document.querySelector('#btn_export_code_hidden').click()"
572
+ )
573
 
574
 
575
  demo.launch()
embedder_imports.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GaussianRandomProjection: from sklearn.random_projection import GaussianRandomProjection
2
+ SparseRandomProjection: from sklearn.random_projection import SparseRandomProjection
3
+ DictionaryLearning: from sklearn.decomposition import DictionaryLearning
4
+ FactorAnalysis: from sklearn.decomposition import FactorAnalysis
5
+ FastICA: from sklearn.decomposition import FastICA
6
+ IncrementalPCA: from sklearn.decomposition import IncrementalPCA
7
+ KernelPCA: from sklearn.decomposition import KernelPCA
8
+ LatentDirichletAllocation: from sklearn.decomposition import LatentDirichletAllocation
9
+ MiniBatchDictionaryLearning: from sklearn.decomposition import MiniBatchDictionaryLearning
10
+ MiniBatchNMF: from sklearn.decomposition import MiniBatchNMF
11
+ MiniBatchSparsePCA: from sklearn.decomposition import MiniBatchSparsePCA
12
+ NMF: from sklearn.decomposition import NMF
13
+ PCA: from sklearn.decomposition import PCA
14
+ SparsePCA: from sklearn.decomposition import SparsePCA
15
+ TruncatedSVD: from sklearn.decomposition import TruncatedSVD
export_code_template.py.j2 ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import pandas as pd
5
+ {{ model_import_statement }}
6
+ {{ embedder_import_statement }}
7
+
8
+
9
def label2color(labels):
    """Pick one plotting color per label.

    A label whose lower-cased text is a CSS4 color name is drawn in that
    color; any other label gets the "tab10" colormap entry at its position
    modulo 10.
    """
    palette = plt.get_cmap("tab10")
    css_names = set(matplotlib.colors.CSS4_COLORS.keys())

    def pick(position, label):
        name = label.lower()
        return name if name in css_names else palette(position % 10)

    return [pick(position, label) for position, label in enumerate(labels)]
19
+
20
+
21
def load_dataset(file_path):
    """Load the feature matrix and class labels from a CSV file.

    Args:
        file_path: path of a CSV file with a ``target`` column holding the
            class labels; every other column is treated as a feature.

    Returns:
        ``(X, y)``: the feature values as a 2-D array and the labels as an
        array of strings.

    Raises:
        ValueError: if the file has no ``target`` column or no data rows.
    """
    data = pd.read_csv(file_path)

    # Report a missing label column as a format problem instead of letting
    # the column lookup below fail with a bare KeyError.
    if 'target' not in data.columns:
        raise ValueError("The dataset has no 'target' column.")

    X = data.loc[:, data.columns != 'target'].values
    y = data['target'].values.astype(str)

    if len(X) == 0:
        raise ValueError("The dataset is empty or not properly formatted.")

    return X, y
31
+
32
+
33
def create_plot(X, y, model, embedder):
    """Draw the embedded data points and the model's decision regions.

    X and y are the features and string labels, model is a fitted classifier
    and embedder is a fitted embedder providing ``transform`` and
    ``inverse_transform``. Everything is drawn on the current matplotlib
    figure.
    """
    classes = np.unique(y)
    class_colors = label2color(classes)
    color_of = dict(zip(classes, class_colors))

    # data points: one scatter call per class so each class gets a legend entry
    embedded = embedder.transform(X)
    for cls, color in zip(classes, class_colors):
        members = embedded[y == cls]
        plt.scatter(members[:, 0], members[:, 1], color=color, label=cls)
    plt.legend()

    # decision regions: classify a regular grid of points in the embedded
    # space (mapped back to the input space) and color each grid point by
    # its predicted class
    x_axis = np.linspace({{ x_min }}, {{ x_max }}, {{ num_dots }})
    y_axis = np.linspace({{ y_min }}, {{ y_max }}, {{ num_dots }})
    grid_x, grid_y = np.meshgrid(x_axis, y_axis)
    grid_x = grid_x.ravel()
    grid_y = grid_y.ravel()

    grid_points = np.c_[grid_x, grid_y]
    predicted = model.predict(embedder.inverse_transform(grid_points)).ravel()
    plt.scatter(grid_x, grid_y, c=[color_of[p] for p in predicted], s=1, alpha=0.5)
58
+
59
+
60
def main():
    """Load the exported dataset, fit the embedder and the model, and show the plot."""
    # data loading and preprocessing
    X, y = load_dataset("{{ dataset_file }}")
    embedder = {{ embedder_class }}(n_components=2).fit(X)

    # model training
    model = {{ model_class }}()
    model.fit(X, y)

    # create the figure with the size and resolution the template is rendered
    # with, so the exported plot has the same dimensions as the original canvas
    plt.figure(figsize=({{ fig_width }}, {{ fig_height }}), dpi={{ dpi }})
    create_plot(X, y, model, embedder)
    # uncomment the line below if you want to save as image
    # (it has to run before plt.show(); saving after the window is closed
    # would normally write an empty figure)
    # plt.savefig("plot.png")
    plt.show()


if __name__ == "__main__":
    main()
model_imports.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LinearSVC: from sklearn.svm import LinearSVC
2
+ CalibratedClassifierCV: from sklearn.calibration import CalibratedClassifierCV
3
+ LinearDiscriminantAnalysis: from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
4
+ QuadraticDiscriminantAnalysis: from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
5
+ DummyClassifier: from sklearn.dummy import DummyClassifier
6
+ AdaBoostClassifier: from sklearn.ensemble import AdaBoostClassifier
7
+ BaggingClassifier: from sklearn.ensemble import BaggingClassifier
8
+ ExtraTreesClassifier: from sklearn.ensemble import ExtraTreesClassifier
9
+ GradientBoostingClassifier: from sklearn.ensemble import GradientBoostingClassifier
10
+ HistGradientBoostingClassifier: from sklearn.ensemble import HistGradientBoostingClassifier
11
+ RandomForestClassifier: from sklearn.ensemble import RandomForestClassifier
12
+ GaussianProcessClassifier: from sklearn.gaussian_process import GaussianProcessClassifier
13
+ LogisticRegression: from sklearn.linear_model import LogisticRegression
14
+ PassiveAggressiveClassifier: from sklearn.linear_model import PassiveAggressiveClassifier
15
+ Perceptron: from sklearn.linear_model import Perceptron
16
+ RidgeClassifier: from sklearn.linear_model import RidgeClassifier
17
+ SGDClassifier: from sklearn.linear_model import SGDClassifier
18
+ GaussianNB: from sklearn.naive_bayes import GaussianNB
19
+ KNeighborsClassifier: from sklearn.neighbors import KNeighborsClassifier
20
+ NearestCentroid: from sklearn.neighbors import NearestCentroid
21
+ RadiusNeighborsClassifier: from sklearn.neighbors import RadiusNeighborsClassifier
22
+ MLPClassifier: from sklearn.neural_network import MLPClassifier
23
+ LabelPropagation: from sklearn.semi_supervised import LabelPropagation
24
+ LabelSpreading: from sklearn.semi_supervised import LabelSpreading
25
+ NuSVC: from sklearn.svm import NuSVC
26
+ SVC: from sklearn.svm import SVC
27
+ DecisionTreeClassifier: from sklearn.tree import DecisionTreeClassifier
28
+ ExtraTreeClassifier: from sklearn.tree import ExtraTreeClassifier