joel-woodfield committed on
Commit
16aafd3
·
1 Parent(s): a13fdc8

Add plotting and dataset options

Browse files
Files changed (3) hide show
  1. dataset_options.py +257 -0
  2. mlp_visualizer.py +34 -559
  3. mlp_visualizer_old.py +662 -0
dataset_options.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import numexpr
4
+ import pandas as pd
5
+ import time
6
+
7
+
8
# Named constants exposed inside user-entered numexpr expressions
# (e.g. "sin(2 * pi * x)"). Both 'pi' and 'PI' map to np.pi so either
# spelling works in the function textbox.
NUMEXPR_CONSTANTS = {
    'pi': np.pi,
    'PI': np.pi,
    'e': np.e,
}
13
+
14
+
15
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate *function* (a numexpr expression in ``x``) on an even grid.

    Returns ``(x, y)`` where ``x`` has shape ``(nsample, 1)`` and ``y`` is
    the 1-D array of evaluated values.
    """
    grid = np.linspace(xlim[0], xlim[1], nsample)
    namespace = {'x': grid, **NUMEXPR_CONSTANTS}
    values = numexpr.evaluate(function, local_dict=namespace)
    return grid.reshape(-1, 1), values
20
+
21
+
22
+
23
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Sample ``nsample`` noisy observations of *function* at random x.

    Raises ValueError if ``nsample`` exceeds the fixed pool size of 100.
    Returns ``(x, y)`` with ``x`` of shape ``(nsample, 1)``, sorted by x.
    """
    num_points_to_generate = 100
    if nsample > num_points_to_generate:
        raise ValueError(f"nsample too large, limit to {num_points_to_generate}")

    # Always draw the full pool of 100 uniforms, then keep a prefix: growing
    # nsample adds points without reshuffling the ones already shown.
    rng = np.random.default_rng(seed)
    xs = np.sort(rng.uniform(xlim[0], xlim[1], size=num_points_to_generate)[:nsample])

    # A fresh generator with the same seed keeps the noise stream independent
    # of how many uniform draws were consumed above.
    noise = sigma * np.random.default_rng(seed).standard_normal(nsample)
    ys = numexpr.evaluate(function, local_dict={'x': xs, **NUMEXPR_CONSTANTS}) + noise

    return xs.reshape(-1, 1), ys
39
+
40
+
41
+ class DatasetOptions:
42
+ def __init__(
43
+ self,
44
+ mode: str = "generate",
45
+ function: str = "x ** 2",
46
+ xmin: float = -1.0,
47
+ xmax: float = 1.0,
48
+ nsample: int = 30,
49
+ sigma: float = 0.0,
50
+ seed: int = 0,
51
+ csv_path: str = None,
52
+ ):
53
+ self.mode = mode
54
+
55
+ self.function = function
56
+ self.xmin = xmin
57
+ self.xmax = xmax
58
+ self.nsample = nsample
59
+ self.sigma = sigma
60
+ self.seed = seed
61
+
62
+ self.csv_path = csv_path
63
+
64
+ self.x, self.y = self._get_data()
65
+
66
+ def _get_data(self):
67
+ if self.mode == "generate":
68
+ return get_data_points(
69
+ function=self.function,
70
+ xlim=(self.xmin, self.xmax),
71
+ nsample=self.nsample,
72
+ sigma=self.sigma,
73
+ seed=self.seed,
74
+ )
75
+
76
+ elif self.mode == "csv":
77
+ if self.csv_path is None:
78
+ return np.array([]), np.array([])
79
+
80
+ df = pd.read_csv(self.csv_path)
81
+ if df.shape[1] != 2:
82
+ raise ValueError("CSV file must have exactly two columns")
83
+
84
+ x = df.iloc[:, 0].values.reshape(-1, 1)
85
+ y = df.iloc[:, 1].values
86
+ return x, y
87
+
88
+ else:
89
+ raise ValueError(f"Unknown dataset mode: {self.mode}")
90
+
91
+ def update(self, **kwargs):
92
+ return DatasetOptions(
93
+ mode=kwargs.get("mode", self.mode),
94
+ function=kwargs.get("function", self.function),
95
+ xmin=kwargs.get("xmin", self.xmin),
96
+ xmax=kwargs.get("xmax", self.xmax),
97
+ nsample=kwargs.get("nsample", self.nsample),
98
+ sigma=kwargs.get("sigma", self.sigma),
99
+ seed=kwargs.get("seed", self.seed),
100
+ csv_path=kwargs.get("csv_path", self.csv_path),
101
+ )
102
+
103
+ def _safe_hash(self, val: int) -> int | tuple[int, str]:
104
+ # special handling for -1 (same hash number as -2)
105
+ if val == -1:
106
+ return (-1, "special")
107
+ return val
108
+
109
+ def __hash__(self):
110
+ return hash(
111
+ (
112
+ self.mode,
113
+ self.function,
114
+ self._safe_hash(self.xmin),
115
+ self._safe_hash(self.xmax),
116
+ self.nsample,
117
+ self.sigma,
118
+ self.seed,
119
+ self.csv_path,
120
+ )
121
+ )
122
+
123
+
124
class DatasetOptionsView:
    """Gradio UI for editing a DatasetOptions held in a gr.State.

    Event callbacks receive the State's current *value* (a DatasetOptions),
    build an updated copy via DatasetOptions.update, and return it so gradio
    stores the new snapshot and fires the State's change event.
    """

    def update_mode(self, mode: str, state: DatasetOptions):
        """Switch between "generate" and "csv" dataset modes.

        Returns the new state followed by visibility updates for the seven
        mode-dependent widgets, in the exact order wired in build()'s
        mode.change outputs list.
        """
        state = state.update(mode=mode)

        if mode == "generate":
            return (
                state,
                gr.update(visible=True),  # function
                gr.update(visible=True),  # xmin
                gr.update(visible=True),  # xmax
                gr.update(visible=True),  # sigma
                gr.update(visible=True),  # nsample
                gr.update(visible=True),  # regenerate
                gr.update(visible=False),  # csv upload
            )
        elif mode == "csv":
            return (
                state,
                gr.update(visible=False),  # function
                gr.update(visible=False),  # xmin
                gr.update(visible=False),  # xmax
                gr.update(visible=False),  # sigma
                gr.update(visible=False),  # nsample
                gr.update(visible=False),  # regenerate
                gr.update(visible=True),  # csv upload
            )
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def upload_csv(self, file, state: DatasetOptions):
        """Point the dataset at an uploaded CSV file.

        DatasetOptions reads the file in its constructor, so a malformed CSV
        raises here; on failure a toast is shown and the previous state is
        returned unchanged.
        """
        try:
            state = state.update(
                mode="csv",
                csv_path=file.name,
            )

        except Exception as e:
            gr.Info(f"⚠️ {e}")

        return state

    def regenerate_data(self, state: DatasetOptions):
        """Resample the data by deriving a fresh seed from the wall clock."""
        # millisecond clock folded into numpy's 32-bit seed range
        seed = int(time.time() * 1000) % (2 ** 32)
        state = state.update(seed=seed)
        return state

    def build(self, state: gr.State):
        """Create the dataset widgets and wire their events to *state*.

        Initial widget values are taken from the DatasetOptions currently
        stored in the State.
        """
        options = state.value

        with gr.Column():
            mode = gr.Radio(
                label="Dataset",
                choices=["generate", "csv"],
                value="generate",
            )

            # "generate" mode controls
            function = gr.Textbox(
                label="Function (in terms of x)",
                value=options.function,
            )
            with gr.Row():
                xmin = gr.Number(
                    label="X min",
                    value=options.xmin,
                )
                xmax = gr.Number(
                    label="X max",
                    value=options.xmax,
                )
            sigma = gr.Number(
                label="Gaussian noise standard deviation",
                value=options.sigma,
            )
            nsample = gr.Slider(
                label="Number of samples",
                minimum=1,
                maximum=100,
                step=1,
                value=options.nsample,
            )
            regenerate = gr.Button("Regenerate Data")

            # "csv" mode control
            csv_upload = gr.File(
                label="Upload CSV file",
                file_types=['.csv'],
                visible=False,  # function mode is default
            )

        # mode switch toggles visibility of every control above
        mode.change(
            fn=self.update_mode,
            inputs=[mode, state],
            outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
        )

        # function
        function.submit(
            lambda f, s: s.update(function=f),
            inputs=[function, state],
            outputs=[state],
        )
        xmin.submit(
            lambda xmn, s: s.update(xmin=xmn),
            inputs=[xmin, state],
            outputs=[state],
        )
        xmax.submit(
            lambda xmx, s: s.update(xmax=xmx),
            inputs=[xmax, state],
            outputs=[state],
        )
        sigma.submit(
            lambda sig, s: s.update(sigma=sig),
            inputs=[sigma, state],
            outputs=[state],
        )
        nsample.change(
            lambda n, s: s.update(nsample=n),
            inputs=[nsample, state],
            outputs=[state],
        )
        regenerate.click(
            fn=self.regenerate_data,
            inputs=[state],
            outputs=[state],
        )

        # csv upload
        csv_upload.upload(
            self.upload_csv,
            inputs=[csv_upload, state],
            outputs=[state],
        )
256
+
257
+
mlp_visualizer.py CHANGED
@@ -1,4 +1,5 @@
1
  from collections import deque
 
2
  import functools
3
  from pathlib import Path
4
  import pickle
@@ -28,190 +29,14 @@ logging.basicConfig(
28
  )
29
  logger = logging.getLogger("ELVIS")
30
 
31
-
32
- NUMEXPR_CONSTANTS = {
33
- 'pi': np.pi,
34
- 'PI': np.pi,
35
- 'e': np.e,
36
- }
37
-
38
-
39
- def get_function(function, xlim=(-1, 1), nsample=100):
40
- x = np.linspace(xlim[0], xlim[1], nsample)
41
- y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
42
- x = x.reshape(-1, 1)
43
- return x, y
44
-
45
-
46
- def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
47
- num_points_to_generate = 100
48
- if nsample > num_points_to_generate:
49
- raise ValueError(f"nsample too large, limit to {num_points_to_generate}")
50
-
51
- rng = np.random.default_rng(seed)
52
- x = rng.uniform(xlim[0], xlim[1], size=num_points_to_generate)
53
- x = x[:nsample]
54
- x = np.sort(x)
55
-
56
- rng = np.random.default_rng(seed)
57
- noise = sigma * rng.standard_normal(nsample)
58
- y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS}) + noise
59
-
60
- x = x.reshape(-1, 1)
61
- return x, y
62
-
63
-
64
- class HiddenLayerBox:
65
- def __init__(self, initially_visible=False):
66
- with gr.Row():
67
- self.hidden_units = gr.Number(label="Hidden units", value=64, visible=initially_visible)
68
- self.activation = gr.Textbox(label="Activation", value="ReLU", visible=initially_visible)
69
-
70
- def set_visibility(self, visible):
71
- return [
72
- gr.update(visible=visible),
73
- gr.update(visible=visible),
74
- ]
75
-
76
- def get_values(self):
77
- return [self.hidden_units, self.activation]
78
-
79
-
80
- class ArchitectureComponent:
81
- def __init__(self, update_architecture_callback, canvas, max_layers=5):
82
- self.num_show = 2
83
- self.components = []
84
- for i in range(max_layers):
85
- comp = HiddenLayerBox(initially_visible=(i < self.num_show))
86
- self.components.append(comp)
87
-
88
- self.update_architecture_callback = update_architecture_callback
89
- self.canvas = canvas
90
-
91
- def update_architecture(self, *values):
92
- # values come as [hidden1, act1, hidden2, act2, ...]
93
- hidden_layers = []
94
- activations = []
95
- for i in range(0, self.num_show * 2, 2):
96
- if values[i] != "" or values[i + 1] != "":
97
- hidden_layers.append(values[i])
98
- activations.append(values[i + 1])
99
- return self.update_architecture_callback(hidden_layers, activations)
100
-
101
- def build(self):
102
- with gr.Row():
103
- add_btn = gr.Button("Add layer")
104
- remove_btn = gr.Button("Remove layer")
105
-
106
- with gr.Row():
107
- gr.Number(label="Output units", value=1, interactive=False)
108
- gr.Textbox(label="Activation", value="Identity", interactive=False)
109
-
110
- # Collect all subcomponents
111
- all_outputs = []
112
- for comp in self.components:
113
- all_outputs += [comp.hidden_units, comp.activation]
114
-
115
- def on_add():
116
- self.num_show = min(self.num_show + 1, len(self.components))
117
- updates = []
118
- for i, comp in enumerate(self.components):
119
- updates += comp.set_visibility(i < self.num_show)
120
-
121
- updates += [gr.update(value=self.num_show)]
122
- return updates
123
-
124
- def on_remove():
125
- self.num_show = max(self.num_show - 1, 0)
126
- updates = []
127
- for i, comp in enumerate(self.components):
128
- updates += comp.set_visibility(i < self.num_show)
129
-
130
- updates += [gr.update(value=self.num_show)]
131
- return updates
132
-
133
- hidden_counter = gr.Number(value=self.num_show, visible=False)
134
-
135
- add_btn.click(on_add, outputs=[*all_outputs, hidden_counter] )
136
- remove_btn.click(on_remove, outputs=[*all_outputs, hidden_counter] )
137
-
138
- for output in all_outputs:
139
- output.submit(
140
- fn=self.update_architecture,
141
- inputs=all_outputs,
142
- outputs=[self.canvas],
143
- )
144
- hidden_counter.change(
145
- fn=self.update_architecture,
146
- inputs=all_outputs,
147
- outputs=[self.canvas],
148
- )
149
 
150
 
151
  class MlpVisualizer:
152
- DEFAULT_FUNCTION = "sin(2 * pi * x)"
153
-
154
- DEFAULT_OPTIMIZER = "SGD"
155
- DEFAULT_LEARNING_RATE = 0.01
156
-
157
- DEFAULT_OPTIMIZER_HPARAMS = {
158
- "SGD": {
159
- "learning_rate": 0.1,
160
- "momentum": 0.0,
161
- },
162
- "Adam": {
163
- "learning_rate": 0.01,
164
- "beta1": 0.9,
165
- "beta2": 0.999,
166
- "eps": 1e-8,
167
- },
168
- }
169
-
170
- def _init_state(self):
171
- self.data_options = {
172
- "function": self.DEFAULT_FUNCTION,
173
- "nsample": 30,
174
- "sigma": 0,
175
- "seed": 0,
176
- "x_min": -1,
177
- "x_max": 1,
178
- }
179
- self.x_train, self.y_train = self.generate_data()
180
-
181
- self.architecture_options = {
182
- "hidden_layers": [64, 64],
183
- "activations": ["ReLU", "ReLU"],
184
- }
185
- self.basic_train_hparams = {
186
- "batch_size": self.x_train.shape[0],
187
- "optimizer": self.DEFAULT_OPTIMIZER,
188
- }
189
-
190
- # important to copy dict
191
- self.optimizer_hparams = {}
192
- for opt, params in self.DEFAULT_OPTIMIZER_HPARAMS.items():
193
- self.optimizer_hparams[opt] = params.copy()
194
-
195
- # do not initialise here, otherwise gradio will make it not work
196
- # self.param_components = {}
197
-
198
- self.criterion = nn.MSELoss()
199
- self.model, self.optimizer, self.train_loss = self.init_model()
200
- self.num_steps_trained = 0
201
-
202
-
203
- self.plot_options = {
204
- "show_training_data": True,
205
- "show_true_function": True,
206
- "show_predictions": True,
207
- }
208
-
209
  def __init__(self, width, height):
210
  self.canvas_width = width
211
  self.canvas_height = height
212
 
213
- self._init_state()
214
-
215
  self.plot_cmap = plt.get_cmap("tab20")
216
 
217
  self.css = """
@@ -219,97 +44,38 @@ class MlpVisualizer:
219
  display: none;
220
  }"""
221
 
222
- def on_load(self):
223
- self._init_state()
224
-
225
- def generate_data(self):
226
- function = self.data_options["function"]
227
- nsample = self.data_options["nsample"]
228
- sigma = self.data_options["sigma"]
229
- x_min = self.data_options["x_min"]
230
- x_max = self.data_options["x_max"]
231
-
232
- return get_data_points(function, xlim=(x_min, x_max), nsample=nsample, sigma=sigma, seed=self.data_options["seed"])
233
-
234
- def init_model(self):
235
- print(self.architecture_options)
236
- layers = []
237
- input_size = 1
238
- for hidden_units, activation in zip(self.architecture_options["hidden_layers"], self.architecture_options["activations"]):
239
- layers.append(nn.Linear(input_size, hidden_units))
240
- if activation == "ReLU":
241
- layers.append(nn.ReLU())
242
- elif activation == "Sigmoid":
243
- layers.append(nn.Sigmoid())
244
- elif activation == "Tanh":
245
- layers.append(nn.Tanh())
246
- elif activation == "LeakyReLU":
247
- layers.append(nn.LeakyReLU())
248
- elif activation == "Identity":
249
- layers.append(nn.Identity())
250
- else:
251
- raise ValueError(f"Unsupported activation: {activation}")
252
- input_size = hidden_units
253
-
254
- output_layer = nn.Linear(input_size, 1)
255
- model = nn.Sequential(*layers, output_layer)
256
-
257
- if self.basic_train_hparams["optimizer"] == "Adam":
258
- optimizer = torch.optim.Adam(
259
- model.parameters(),
260
- lr=self.optimizer_hparams["Adam"]["learning_rate"],
261
- betas=(self.optimizer_hparams["Adam"]["beta1"], self.optimizer_hparams["Adam"]["beta2"]),
262
- eps=self.optimizer_hparams["Adam"]["eps"],
263
- )
264
- elif self.basic_train_hparams["optimizer"] == "SGD":
265
- optimizer = torch.optim.SGD(
266
- model.parameters(),
267
- lr=self.optimizer_hparams["SGD"]["learning_rate"],
268
- momentum=self.optimizer_hparams["SGD"]["momentum"],
269
- )
270
- else:
271
- raise ValueError(f"Unsupported optimizer: {self.basic_train_hparams['optimizer']}")
272
-
273
- self.num_steps_trained = 0
274
-
275
- # compute initial train loss
276
- model.eval()
277
- inputs = torch.from_numpy(self.x_train).float()
278
- targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
279
- with torch.no_grad():
280
- outputs = model(inputs)
281
- train_loss = self.criterion(outputs, targets).item()
282
-
283
- return model, optimizer, train_loss
284
-
285
- def plot(self):
286
- '''
287
- '''
288
  t1 = time.time()
289
- logger.info("Initializing figure")
290
- fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=100)
291
  # set entire figure to be the canvas to allow simple conversion of mouse
292
  # position to coordinates in the figure
293
  ax = fig.add_axes([0., 0., 1., 1.]) #
294
  ax.margins(x=0, y=0) # no padding in both directions
295
 
296
- x_test, y_test = get_function(self.data_options["function"], xlim=(-2, 2), nsample=100)
297
- y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
 
 
298
 
299
  # plot
300
  fig, ax = plt.subplots(figsize=(8, 8))
301
  ax.set_title("")
302
  ax.set_xlabel("x")
303
  ax.set_ylabel("y")
304
- ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
305
 
306
- if self.plot_options["show_training_data"]:
307
- plt.scatter(self.x_train.flatten(), self.y_train, label='training data', color=self.plot_cmap(0))
 
 
 
 
 
308
 
309
- if self.plot_options["show_true_function"]:
310
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
311
 
312
- if self.plot_options["show_predictions"]:
313
  plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
314
 
315
  plt.legend()
@@ -325,335 +91,44 @@ class MlpVisualizer:
325
 
326
  return img
327
 
328
- def _update_data_seed(self):
329
- self.data_options["seed"] += 1
330
- self.x_train, self.y_train = self.generate_data()
331
- self.reset_model()
332
- return self.plot(), self.num_steps_trained, self.train_loss
333
-
334
- def reset_model(self):
335
- self.model, self.optimizer, self.train_loss = self.init_model()
336
- return self.plot(), self.num_steps_trained, self.train_loss
337
-
338
- def update_data_options(self, **kwargs):
339
- for key, value in kwargs.items():
340
- if key in self.data_options:
341
-
342
- # if function - test if valid
343
- if key == "function":
344
- try:
345
- x = np.linspace(-1, 1, 10)
346
- y = numexpr.evaluate(value, local_dict={'x': x, **NUMEXPR_CONSTANTS})
347
- except Exception as e:
348
- raise ValueError(f"Invalid function: {e}")
349
-
350
- self.data_options[key] = value
351
-
352
- # reset data and model
353
- self.x_train, self.y_train = self.generate_data()
354
- self.reset_model()
355
-
356
- if "nsample" in kwargs:
357
- slider_update = gr.update(maximum=self.x_train.shape[0], value=min(self.basic_train_hparams["batch_size"], self.x_train.shape[0]))
358
- return self.plot(), slider_update, self.num_steps_trained, self.train_loss
359
-
360
- return self.plot(), self.num_steps_trained, self.train_loss
361
-
362
- def update_plot_options(self, **kwargs):
363
- for key, value in kwargs.items():
364
- if key in self.plot_options:
365
- self.plot_options[key] = value
366
- return self.plot()
367
-
368
- def update_architecture(self, hidden_layers, activations):
369
- self.architecture_options["hidden_layers"] = hidden_layers
370
- self.architecture_options["activations"] = activations
371
-
372
- # reset model
373
- self.model, self.optimizer, self.train_loss = self.init_model()
374
-
375
- return self.plot(), self.num_steps_trained, self.train_loss
376
-
377
- def update_basic_train_hparams(self, **kwargs):
378
- for key, value in kwargs.items():
379
- if key in self.basic_train_hparams:
380
- self.basic_train_hparams[key] = value
381
-
382
- # reset model
383
- self.model, self.optimizer, self.train_loss = self.init_model()
384
-
385
- return self.plot(), self.num_steps_trained, self.train_loss
386
-
387
- def update_optimizer(self, optimizer_name):
388
- self.basic_train_hparams["optimizer"] = optimizer_name
389
- # reset optimizer hyperparameters to default
390
- self.optimizer_hparams[optimizer_name] = self.DEFAULT_OPTIMIZER_HPARAMS[optimizer_name].copy()
391
-
392
- updates = []
393
- for opt_name, params in self.param_components.items():
394
- is_visible = (opt_name == optimizer_name)
395
- for _ in params.values():
396
- updates.append(gr.update(visible=is_visible))
397
-
398
- # reset model
399
- self.model, self.optimizer, self.train_loss = self.init_model()
400
-
401
- return updates + [self.plot(), self.num_steps_trained, self.train_loss]
402
-
403
- def build_optimizer_components(self):
404
- self.param_components = {}
405
- for opt_name, params in self.DEFAULT_OPTIMIZER_HPARAMS.items():
406
- opt_dict = {}
407
- for param_name, param_value in params.items():
408
- opt_dict[param_name] = gr.Number(
409
- label=f"{param_name}",
410
- value=param_value,
411
- visible=(opt_name == self.DEFAULT_OPTIMIZER),
412
- interactive=True,
413
- )
414
- self.param_components[opt_name] = opt_dict
415
-
416
- all_param_components = [
417
- comp for opt in self.param_components.values() for comp in opt.values()
418
- ]
419
- return all_param_components
420
-
421
- def update_hparam(self, value, optimizer_name, param_name):
422
- self.optimizer_hparams[optimizer_name][param_name] = value
423
-
424
- # reset model and plot
425
- self.model, self.optimizer, self.train_loss = self.init_model()
426
- return self.plot(), self.num_steps_trained, self.train_loss
427
-
428
- def train_step(self):
429
- self.model.train()
430
-
431
- inputs = torch.from_numpy(self.x_train).float()
432
- targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
433
- outputs = self.model(inputs)
434
- loss = self.criterion(outputs, targets)
435
-
436
- self.optimizer.zero_grad()
437
- loss.backward()
438
- self.optimizer.step()
439
-
440
- self.num_steps_trained += 1
441
-
442
- # update train loss
443
- self.model.eval()
444
- with torch.no_grad():
445
- outputs = self.model(inputs)
446
- self.train_loss = self.criterion(outputs, targets).item()
447
-
448
- return self.plot(), self.num_steps_trained, self.train_loss
449
-
450
  def launch(self):
451
  # build the Gradio interface
452
  with gr.Blocks(css=self.css) as demo:
453
  # app title
454
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
455
 
 
 
 
456
  # GUI elements and layout
457
  with gr.Row():
458
  with gr.Column(scale=2):
459
- self.canvas = gr.Image(
460
- value=self.plot(),
461
  show_download_button=False,
462
  container=True,
463
  )
464
 
465
  with gr.Column(scale=1):
466
  with gr.Tab("Dataset"):
467
- dataset_radio = gr.Radio(
468
- ["Generate", "Upload"],
469
- value="Generate",
470
- label="Dataset",
 
 
471
  )
472
 
473
- with gr.Column():
474
- function_box = gr.Textbox(
475
- label="Function",
476
- placeholder="function of x",
477
- value=self.DEFAULT_FUNCTION,
478
- interactive=True,
479
- )
480
- with gr.Row():
481
- x_min = gr.Number(
482
- label="Min x",
483
- value=-1,
484
- interactive=True,
485
- )
486
- x_max = gr.Number(
487
- label="Max x",
488
- value=1,
489
- interactive=True,
490
- )
491
- with gr.Row():
492
- noise_value = gr.Number(
493
- label="Gaussian noise standard deviation",
494
- value=0,
495
- interactive=True,
496
- )
497
- num_points_slider = gr.Slider(
498
- label="Number of data points",
499
- minimum=0,
500
- maximum=100,
501
- step=1,
502
- value=30,
503
- interactive=True,
504
- )
505
-
506
- regenerate_button = gr.Button("Regenerate Data")
507
-
508
- # upload data
509
- file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
510
- self.file_chooser = file_chooser
511
-
512
  with gr.Tab("Architecture"):
513
- self.architecture_component = ArchitectureComponent(self.update_architecture, self.canvas)
514
- self.architecture_component.build()
515
-
516
  with gr.Tab("Train"):
517
- optimizer_radio = gr.Radio(
518
- ["SGD", "Adam"],
519
- value=self.DEFAULT_OPTIMIZER,
520
- label="Optimizer",
521
- )
522
-
523
- all_param_components = self.build_optimizer_components()
524
- self.temp = all_param_components
525
-
526
- batch_size_slider = gr.Slider(
527
- label="Batch Size",
528
- minimum=1,
529
- maximum=self.x_train.shape[0],
530
- step=1,
531
- value=self.x_train.shape[0],
532
- interactive=True,
533
- )
534
-
535
- with gr.Row():
536
- train_step_counter = gr.Number(
537
- label="Train steps",
538
- value=0,
539
- interactive=False,
540
- )
541
- train_loss_display = gr.Number(
542
- label="Train loss",
543
- value=self.train_loss,
544
- interactive=False,
545
- )
546
-
547
- train_button = gr.Button("Train Step")
548
- reset_model_button = gr.Button("Reset Model")
549
-
550
  with gr.Tab("Plot"):
551
- # plot show options
552
- with gr.Column():
553
- with gr.Row():
554
- show_training_data = gr.Checkbox(label="Show training data", value=True)
555
- show_true_function = gr.Checkbox(label="Show true function", value=True)
556
- with gr.Row():
557
- show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
558
-
559
- #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
560
-
561
  with gr.Tab("Export"):
562
- # use hidden download button to generate files on the fly
563
- # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
564
-
565
- btn_export_data = gr.Button("Data")
566
- btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
567
-
568
- btn_export_model = gr.Button('Model')
569
- btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
570
-
571
- btn_export_code = gr.Button('Code')
572
- btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
573
-
574
  with gr.Tab("Usage"):
575
- gr.Markdown(''.join(open('usage.md', 'r').readlines()))
576
-
577
- # data options
578
- function_box.submit(
579
- fn=lambda function: self.update_data_options(function=function),
580
- inputs=function_box,
581
- outputs=[self.canvas, train_step_counter, train_loss_display],
582
- )
583
- x_min.submit(
584
- fn=lambda xmin: self.update_data_options(x_min=xmin),
585
- inputs=x_min,
586
- outputs=[self.canvas, train_step_counter, train_loss_display],
587
- )
588
- x_max.submit(
589
- fn=lambda xmax: self.update_data_options(x_max=xmax),
590
- inputs=x_max,
591
- outputs=[self.canvas, train_step_counter, train_loss_display],
592
- )
593
- num_points_slider.change(
594
- fn=lambda nsample: self.update_data_options(nsample=nsample),
595
- inputs=num_points_slider,
596
- outputs=[self.canvas, batch_size_slider, train_step_counter, train_loss_display],
597
- )
598
- noise_value.submit(
599
- fn=lambda sigma: self.update_data_options(sigma=sigma),
600
- inputs=noise_value,
601
- outputs=[self.canvas, train_step_counter, train_loss_display],
602
- )
603
- regenerate_button.click(
604
- fn=self._update_data_seed,
605
- outputs=[self.canvas, train_step_counter, train_loss_display],
606
- )
607
-
608
- # train options
609
- optimizer_radio.change(
610
- fn=self.update_optimizer,
611
- inputs=optimizer_radio,
612
- outputs=[*all_param_components, self.canvas, train_step_counter, train_loss_display],
613
- )
614
- batch_size_slider.change(
615
- fn=lambda batch_size: self.update_basic_train_hparams(batch_size=batch_size),
616
- inputs=batch_size_slider,
617
- outputs=[self.canvas, train_step_counter, train_loss_display],
618
- )
619
- train_button.click(
620
- fn=self.train_step,
621
- outputs=[self.canvas, train_step_counter, train_loss_display],
622
- show_progress="hidden",
623
- )
624
- reset_model_button.click(
625
- fn=self.reset_model,
626
- outputs=[self.canvas, train_step_counter, train_loss_display],
627
- )
628
- for opt_name, params in self.param_components.items():
629
- for param_name, comp in params.items():
630
- comp.submit(
631
- fn=functools.partial(self.update_hparam, optimizer_name=opt_name, param_name=param_name),
632
- inputs=[comp],
633
- outputs=[self.canvas, train_step_counter, train_loss_display],
634
- )
635
-
636
- # plot options
637
- show_training_data.change(
638
- fn=lambda show: self.update_plot_options(show_training_data=show),
639
- inputs=show_training_data,
640
- outputs=[self.canvas],
641
- show_progress="hidden",
642
- )
643
- show_true_function.change(
644
- fn=lambda show: self.update_plot_options(show_true_function=show),
645
- inputs=show_true_function,
646
- outputs=[self.canvas],
647
- show_progress="hidden",
648
- )
649
- show_predictions.change(
650
- fn=lambda show: self.update_plot_options(show_predictions=show),
651
- inputs=show_predictions,
652
- outputs=[self.canvas],
653
- show_progress="hidden",
654
- )
655
-
656
- demo.load(self.on_load)
657
 
658
  demo.launch()
659
 
 
1
  from collections import deque
2
+ from dataclasses import dataclass, replace
3
  import functools
4
  from pathlib import Path
5
  import pickle
 
29
  )
30
  logger = logging.getLogger("ELVIS")
31
 
32
+ from dataset_options import DatasetOptions, DatasetOptionsView, get_function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
  class MlpVisualizer:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def __init__(self, width, height):
37
  self.canvas_width = width
38
  self.canvas_height = height
39
 
 
 
40
  self.plot_cmap = plt.get_cmap("tab20")
41
 
42
  self.css = """
 
44
  display: none;
45
  }"""
46
 
47
+ def plot(self, dataset_options: DatasetOptions):
48
+ print("Plotting")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  t1 = time.time()
50
+ fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
 
51
  # set entire figure to be the canvas to allow simple conversion of mouse
52
  # position to coordinates in the figure
53
  ax = fig.add_axes([0., 0., 1., 1.]) #
54
  ax.margins(x=0, y=0) # no padding in both directions
55
 
56
+ if dataset_options.mode == "generate":
57
+ x_test, y_test = get_function(dataset_options.function, xlim=(-2, 2), nsample=100)
58
+
59
+ # y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
60
 
61
  # plot
62
  fig, ax = plt.subplots(figsize=(8, 8))
63
  ax.set_title("")
64
  ax.set_xlabel("x")
65
  ax.set_ylabel("y")
 
66
 
67
+ if dataset_options.mode == "generate":
68
+ ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
69
+
70
+ x_train = dataset_options.x
71
+ y_train = dataset_options.y
72
+ if True:
73
+ plt.scatter(x_train.flatten(), y_train, label='training data', color=self.plot_cmap(0))
74
 
75
+ if dataset_options.mode == "generate":
76
  plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
77
 
78
+ if False:
79
  plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
80
 
81
  plt.legend()
 
91
 
92
  return img
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
    def launch(self):
        """Build the Gradio UI and start the app (blocks until shutdown)."""
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")

            # states
            # Dataset configuration lives in a gr.State holding a
            # DatasetOptions snapshot; callbacks replace it wholesale.
            dataset_options = gr.State(DatasetOptions())

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    # main plot canvas, seeded with an initial render
                    canvas = gr.Image(
                        value=self.plot(dataset_options.value),
                        show_download_button=False,
                        container=True,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_view = DatasetOptionsView()
                        dataset_view.build(state=dataset_options)
                        # re-render the canvas whenever the dataset changes
                        dataset_options.change(
                            fn=self.plot,
                            inputs=[dataset_options],
                            outputs=[canvas],
                        )

                    # placeholder tabs (content removed in this refactor)
                    with gr.Tab("Architecture"):
                        gr.Markdown("HI")

                    with gr.Tab("Train"):
                        gr.Markdown("HI")

                    with gr.Tab("Plot"):
                        gr.Markdown("HI")

                    with gr.Tab("Export"):
                        gr.Markdown("HI")

                    with gr.Tab("Usage"):
                        gr.Markdown("HI")

        demo.launch()
134
 
mlp_visualizer_old.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import functools
3
+ from pathlib import Path
4
+ import pickle
5
+ import time
6
+
7
+ import gradio as gr
8
+ import inspect
9
+ import io
10
+ from jinja2 import Template
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.lines as mlines
13
+ import numpy as np
14
+ import numexpr
15
+ import pandas as pd
16
+ from PIL import Image
17
+ import plotly.graph_objects as go
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ import traceback
22
+ import yaml
23
+
24
+ import logging
25
# Configure root logging once at import time for the whole app.
logging.basicConfig(
    level=logging.INFO,  # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
    format="%(asctime)s [%(levelname)s] %(message)s",  # log format
)
# Module-wide logger; "ELVIS" is this app's logger name.
logger = logging.getLogger("ELVIS")


# Named constants made available inside user-entered numexpr expressions
# (e.g. "sin(2 * pi * x)"). Both 'pi' casings are accepted.
NUMEXPR_CONSTANTS = {
    'pi': np.pi,
    'PI': np.pi,
    'e': np.e,
}
37
+
38
+
39
def get_function(function, xlim=(-1, 1), nsample=100):
    """Evaluate a numexpr expression of `x` on an evenly spaced grid.

    Args:
        function: numexpr-syntax expression of the variable `x`.
        xlim: (low, high) interval to sample over.
        nsample: number of grid points.

    Returns:
        (x, y): x with shape (nsample, 1), y with shape (nsample,).
    """
    grid = np.linspace(xlim[0], xlim[1], nsample)
    values = numexpr.evaluate(function, local_dict={'x': grid, **NUMEXPR_CONSTANTS})
    return grid.reshape(-1, 1), values
44
+
45
+
46
def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
    """Sample noisy training points from a numexpr expression of `x`.

    A fixed pool of 100 uniform x-positions is generated so that the first
    `nsample` points stay stable as `nsample` varies (same seed).

    Args:
        function: numexpr-syntax expression of the variable `x`.
        xlim: (low, high) interval to draw x positions from.
        nsample: number of points to return (at most 100).
        sigma: standard deviation of additive Gaussian noise on y.
        seed: RNG seed for both positions and noise.

    Returns:
        (x, y): x with shape (nsample, 1), y with shape (nsample,).

    Raises:
        ValueError: if nsample exceeds the pool size.
    """
    pool_size = 100
    if nsample > pool_size:
        raise ValueError(f"nsample too large, limit to {pool_size}")

    position_rng = np.random.default_rng(seed)
    xs = np.sort(position_rng.uniform(xlim[0], xlim[1], size=pool_size)[:nsample])

    # A fresh generator with the same seed keeps noise reproducible per seed.
    noise_rng = np.random.default_rng(seed)
    noise = sigma * noise_rng.standard_normal(nsample)
    ys = numexpr.evaluate(function, local_dict={'x': xs, **NUMEXPR_CONSTANTS}) + noise

    return xs.reshape(-1, 1), ys
62
+
63
+
64
class HiddenLayerBox:
    """One row of controls describing a hidden layer: unit count + activation."""

    def __init__(self, initially_visible=False):
        # Both widgets live in a single gradio row; visibility is toggled
        # later rather than creating/destroying components.
        with gr.Row():
            self.hidden_units = gr.Number(label="Hidden units", value=64, visible=initially_visible)
            self.activation = gr.Textbox(label="Activation", value="ReLU", visible=initially_visible)

    def set_visibility(self, visible):
        """Return gradio updates toggling visibility of both widgets (units, activation)."""
        return [gr.update(visible=visible) for _ in range(2)]

    def get_values(self):
        """Return the two underlying components in [units, activation] order."""
        return [self.hidden_units, self.activation]
78
+
79
+
80
class ArchitectureComponent:
    """Hidden-layer editor UI wired to a model-rebuild callback.

    Holds a fixed pool of `max_layers` HiddenLayerBox rows; only the first
    `num_show` are visible. Add/Remove toggle visibility (gradio component
    trees are static after build, so rows cannot be created dynamically).
    """

    def __init__(self, update_architecture_callback, canvas, max_layers=5):
        # Number of layer rows currently visible/active.
        self.num_show = 2
        self.components = []
        for i in range(max_layers):
            comp = HiddenLayerBox(initially_visible=(i < self.num_show))
            self.components.append(comp)

        # Called with (hidden_layers, activations) when the user edits the
        # architecture; expected to return the refreshed canvas image.
        self.update_architecture_callback = update_architecture_callback
        self.canvas = canvas

    def update_architecture(self, *values):
        # values come as [hidden1, act1, hidden2, act2, ...]
        # Only the first num_show rows are considered.
        hidden_layers = []
        activations = []
        for i in range(0, self.num_show * 2, 2):
            # NOTE(review): gr.Number yields numeric values, so the != ""
            # check on values[i] is always True for the units field — confirm
            # whether empty-row skipping is still intended.
            if values[i] != "" or values[i + 1] != "":
                hidden_layers.append(values[i])
                activations.append(values[i + 1])
        return self.update_architecture_callback(hidden_layers, activations)

    def build(self):
        """Create the buttons/static rows and register all gradio handlers."""
        with gr.Row():
            add_btn = gr.Button("Add layer")
            remove_btn = gr.Button("Remove layer")

        # Fixed, non-editable display of the output layer.
        with gr.Row():
            gr.Number(label="Output units", value=1, interactive=False)
            gr.Textbox(label="Activation", value="Identity", interactive=False)

        # Collect all subcomponents
        all_outputs = []
        for comp in self.components:
            all_outputs += [comp.hidden_units, comp.activation]

        def on_add():
            # Reveal one more row, bounded by the pool size.
            self.num_show = min(self.num_show + 1, len(self.components))
            updates = []
            for i, comp in enumerate(self.components):
                updates += comp.set_visibility(i < self.num_show)

            updates += [gr.update(value=self.num_show)]
            return updates

        def on_remove():
            # Hide the last visible row, never dropping below zero.
            self.num_show = max(self.num_show - 1, 0)
            updates = []
            for i, comp in enumerate(self.components):
                updates += comp.set_visibility(i < self.num_show)

            updates += [gr.update(value=self.num_show)]
            return updates

        # Hidden counter: buttons only return UI visibility updates, so this
        # invisible Number is changed as a signal to trigger a model rebuild
        # via its .change handler below.
        hidden_counter = gr.Number(value=self.num_show, visible=False)

        add_btn.click(on_add, outputs=[*all_outputs, hidden_counter] )
        remove_btn.click(on_remove, outputs=[*all_outputs, hidden_counter] )

        # Rebuild the model whenever a field is submitted or the layer count
        # changes; the callback's return value refreshes the plot canvas.
        for output in all_outputs:
            output.submit(
                fn=self.update_architecture,
                inputs=all_outputs,
                outputs=[self.canvas],
            )
        hidden_counter.change(
            fn=self.update_architecture,
            inputs=all_outputs,
            outputs=[self.canvas],
        )
+ )
149
+
150
+
151
class MlpVisualizer:
    """Interactive Gradio app for training a small MLP on a 1-D regression task."""

    # Expression (numexpr syntax, variable `x`) used for the initial dataset.
    DEFAULT_FUNCTION = "sin(2 * pi * x)"

    DEFAULT_OPTIMIZER = "SGD"
    # NOTE(review): appears unused — per-optimizer learning rates come from
    # DEFAULT_OPTIMIZER_HPARAMS below; confirm before removing.
    DEFAULT_LEARNING_RATE = 0.01

    # Default hyperparameters for each supported optimizer. These are copied
    # (not referenced) into per-session mutable state in _init_state.
    DEFAULT_OPTIMIZER_HPARAMS = {
        "SGD": {
            "learning_rate": 0.1,
            "momentum": 0.0,
        },
        "Adam": {
            "learning_rate": 0.01,
            "beta1": 0.9,
            "beta2": 0.999,
            "eps": 1e-8,
        },
    }
+ }
169
+
170
    def _init_state(self):
        """Reset all mutable app state: data, architecture, model, plot toggles."""
        self.data_options = {
            "function": self.DEFAULT_FUNCTION,
            "nsample": 30,
            "sigma": 0,   # noise standard deviation
            "seed": 0,    # bumped by "Regenerate Data"
            "x_min": -1,
            "x_max": 1,
        }
        self.x_train, self.y_train = self.generate_data()

        self.architecture_options = {
            "hidden_layers": [64, 64],
            "activations": ["ReLU", "ReLU"],
        }
        self.basic_train_hparams = {
            "batch_size": self.x_train.shape[0],  # defaults to full batch
            "optimizer": self.DEFAULT_OPTIMIZER,
        }

        # important to copy dict — otherwise edits would mutate the class-level
        # defaults shared by all sessions
        self.optimizer_hparams = {}
        for opt, params in self.DEFAULT_OPTIMIZER_HPARAMS.items():
            self.optimizer_hparams[opt] = params.copy()

        # do not initialise here, otherwise gradio will make it not work
        # self.param_components = {}

        self.criterion = nn.MSELoss()
        self.model, self.optimizer, self.train_loss = self.init_model()
        self.num_steps_trained = 0


        # Which elements plot() should draw.
        self.plot_options = {
            "show_training_data": True,
            "show_true_function": True,
            "show_predictions": True,
        }
208
+
209
    def __init__(self, width, height):
        """Create the visualizer with a `width` x `height` pixel canvas."""
        self.canvas_width = width
        self.canvas_height = height

        self._init_state()

        # Categorical palette shared by all plot elements.
        self.plot_cmap = plt.get_cmap("tab20")

        # CSS used to hide the helper download buttons on the Export tab.
        self.css = """
        .hidden-button {
            display: none;
        }"""
221
+
222
    def on_load(self):
        # Re-initialize all state on page (re)load so each visit starts fresh.
        self._init_state()
224
+
225
+ def generate_data(self):
226
+ function = self.data_options["function"]
227
+ nsample = self.data_options["nsample"]
228
+ sigma = self.data_options["sigma"]
229
+ x_min = self.data_options["x_min"]
230
+ x_max = self.data_options["x_max"]
231
+
232
+ return get_data_points(function, xlim=(x_min, x_max), nsample=nsample, sigma=sigma, seed=self.data_options["seed"])
233
+
234
+ def init_model(self):
235
+ print(self.architecture_options)
236
+ layers = []
237
+ input_size = 1
238
+ for hidden_units, activation in zip(self.architecture_options["hidden_layers"], self.architecture_options["activations"]):
239
+ layers.append(nn.Linear(input_size, hidden_units))
240
+ if activation == "ReLU":
241
+ layers.append(nn.ReLU())
242
+ elif activation == "Sigmoid":
243
+ layers.append(nn.Sigmoid())
244
+ elif activation == "Tanh":
245
+ layers.append(nn.Tanh())
246
+ elif activation == "LeakyReLU":
247
+ layers.append(nn.LeakyReLU())
248
+ elif activation == "Identity":
249
+ layers.append(nn.Identity())
250
+ else:
251
+ raise ValueError(f"Unsupported activation: {activation}")
252
+ input_size = hidden_units
253
+
254
+ output_layer = nn.Linear(input_size, 1)
255
+ model = nn.Sequential(*layers, output_layer)
256
+
257
+ if self.basic_train_hparams["optimizer"] == "Adam":
258
+ optimizer = torch.optim.Adam(
259
+ model.parameters(),
260
+ lr=self.optimizer_hparams["Adam"]["learning_rate"],
261
+ betas=(self.optimizer_hparams["Adam"]["beta1"], self.optimizer_hparams["Adam"]["beta2"]),
262
+ eps=self.optimizer_hparams["Adam"]["eps"],
263
+ )
264
+ elif self.basic_train_hparams["optimizer"] == "SGD":
265
+ optimizer = torch.optim.SGD(
266
+ model.parameters(),
267
+ lr=self.optimizer_hparams["SGD"]["learning_rate"],
268
+ momentum=self.optimizer_hparams["SGD"]["momentum"],
269
+ )
270
+ else:
271
+ raise ValueError(f"Unsupported optimizer: {self.basic_train_hparams['optimizer']}")
272
+
273
+ self.num_steps_trained = 0
274
+
275
+ # compute initial train loss
276
+ model.eval()
277
+ inputs = torch.from_numpy(self.x_train).float()
278
+ targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
279
+ with torch.no_grad():
280
+ outputs = model(inputs)
281
+ train_loss = self.criterion(outputs, targets).item()
282
+
283
+ return model, optimizer, train_loss
284
+
285
+ def plot(self):
286
+ '''
287
+ '''
288
+ t1 = time.time()
289
+ logger.info("Initializing figure")
290
+ fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=100)
291
+ # set entire figure to be the canvas to allow simple conversion of mouse
292
+ # position to coordinates in the figure
293
+ ax = fig.add_axes([0., 0., 1., 1.]) #
294
+ ax.margins(x=0, y=0) # no padding in both directions
295
+
296
+ x_test, y_test = get_function(self.data_options["function"], xlim=(-2, 2), nsample=100)
297
+ y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
298
+
299
+ # plot
300
+ fig, ax = plt.subplots(figsize=(8, 8))
301
+ ax.set_title("")
302
+ ax.set_xlabel("x")
303
+ ax.set_ylabel("y")
304
+ ax.set_ylim(y_test.min() - 1, y_test.max() + 1)
305
+
306
+ if self.plot_options["show_training_data"]:
307
+ plt.scatter(self.x_train.flatten(), self.y_train, label='training data', color=self.plot_cmap(0))
308
+
309
+ if self.plot_options["show_true_function"]:
310
+ plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
311
+
312
+ if self.plot_options["show_predictions"]:
313
+ plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
314
+
315
+ plt.legend()
316
+
317
+ buf = io.BytesIO()
318
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
319
+ plt.close(fig)
320
+ buf.seek(0)
321
+ img = Image.open(buf)
322
+
323
+ t2 = time.time()
324
+ logger.info(f"Plotting took {t2 - t1:.4f} seconds")
325
+
326
+ return img
327
+
328
    def _update_data_seed(self):
        """Advance the data seed ("Regenerate Data"), resample, reset the model.

        Returns (plot image, step count, train loss) for the UI outputs.
        """
        self.data_options["seed"] += 1
        self.x_train, self.y_train = self.generate_data()
        # Old weights were fit to the old sample; start training over.
        self.reset_model()
        return self.plot(), self.num_steps_trained, self.train_loss
333
+
334
    def reset_model(self):
        """Re-initialize model/optimizer; return (plot, steps, loss) for the UI."""
        self.model, self.optimizer, self.train_loss = self.init_model()
        return self.plot(), self.num_steps_trained, self.train_loss
337
+
338
    def update_data_options(self, **kwargs):
        """Apply recognized data-option changes, resample data, reset the model.

        Unknown keys are silently ignored. Returns 4 values when "nsample"
        changed (an extra batch-size-slider update), otherwise 3 — the gradio
        handlers are wired with matching output lists.

        Raises:
            ValueError: if a new "function" expression fails to evaluate.
        """
        for key, value in kwargs.items():
            if key in self.data_options:

                # if function - test if valid before accepting it
                if key == "function":
                    try:
                        x = np.linspace(-1, 1, 10)
                        y = numexpr.evaluate(value, local_dict={'x': x, **NUMEXPR_CONSTANTS})
                    except Exception as e:
                        raise ValueError(f"Invalid function: {e}")

                self.data_options[key] = value

        # reset data and model
        self.x_train, self.y_train = self.generate_data()
        self.reset_model()

        if "nsample" in kwargs:
            # Keep the batch-size slider's maximum in sync with the dataset size.
            slider_update = gr.update(maximum=self.x_train.shape[0], value=min(self.basic_train_hparams["batch_size"], self.x_train.shape[0]))
            return self.plot(), slider_update, self.num_steps_trained, self.train_loss

        return self.plot(), self.num_steps_trained, self.train_loss
361
+
362
+ def update_plot_options(self, **kwargs):
363
+ for key, value in kwargs.items():
364
+ if key in self.plot_options:
365
+ self.plot_options[key] = value
366
+ return self.plot()
367
+
368
    def update_architecture(self, hidden_layers, activations):
        """Replace the hidden-layer spec and rebuild the model.

        Returns (plot image, step count, train loss) for the UI outputs.
        """
        self.architecture_options["hidden_layers"] = hidden_layers
        self.architecture_options["activations"] = activations

        # reset model — new architecture invalidates old weights/optimizer
        self.model, self.optimizer, self.train_loss = self.init_model()

        return self.plot(), self.num_steps_trained, self.train_loss
376
+
377
+ def update_basic_train_hparams(self, **kwargs):
378
+ for key, value in kwargs.items():
379
+ if key in self.basic_train_hparams:
380
+ self.basic_train_hparams[key] = value
381
+
382
+ # reset model
383
+ self.model, self.optimizer, self.train_loss = self.init_model()
384
+
385
+ return self.plot(), self.num_steps_trained, self.train_loss
386
+
387
    def update_optimizer(self, optimizer_name):
        """Switch optimizer, reset its hyperparameters to defaults, rebuild model.

        Returns visibility updates for every hyperparameter widget (only the
        selected optimizer's widgets become visible) followed by the refreshed
        plot, step count, and loss — matching the gradio outputs list.
        """
        self.basic_train_hparams["optimizer"] = optimizer_name
        # reset optimizer hyperparameters to default
        self.optimizer_hparams[optimizer_name] = self.DEFAULT_OPTIMIZER_HPARAMS[optimizer_name].copy()

        updates = []
        for opt_name, params in self.param_components.items():
            is_visible = (opt_name == optimizer_name)
            for _ in params.values():
                updates.append(gr.update(visible=is_visible))

        # reset model
        self.model, self.optimizer, self.train_loss = self.init_model()

        return updates + [self.plot(), self.num_steps_trained, self.train_loss]
402
+
403
    def build_optimizer_components(self):
        """Create one gr.Number per optimizer hyperparameter.

        Widgets for non-default optimizers start hidden; update_optimizer
        toggles visibility later. Returns the flat list of all components in
        a stable order (used as gradio outputs).
        """
        self.param_components = {}
        for opt_name, params in self.DEFAULT_OPTIMIZER_HPARAMS.items():
            opt_dict = {}
            for param_name, param_value in params.items():
                opt_dict[param_name] = gr.Number(
                    label=f"{param_name}",
                    value=param_value,
                    visible=(opt_name == self.DEFAULT_OPTIMIZER),
                    interactive=True,
                )
            self.param_components[opt_name] = opt_dict

        # Flatten in insertion order: all params of optimizer 1, then 2, ...
        all_param_components = [
            comp for opt in self.param_components.values() for comp in opt.values()
        ]
        return all_param_components
420
+
421
    def update_hparam(self, value, optimizer_name, param_name):
        """Set one optimizer hyperparameter, rebuild the model, refresh the UI.

        Bound via functools.partial so gradio supplies only `value`.
        """
        self.optimizer_hparams[optimizer_name][param_name] = value

        # reset model and plot
        self.model, self.optimizer, self.train_loss = self.init_model()
        return self.plot(), self.num_steps_trained, self.train_loss
427
+
428
    def train_step(self):
        """Run one gradient step and return (plot, steps, loss) for the UI.

        NOTE(review): basic_train_hparams["batch_size"] is not consulted here —
        every step trains on the entire training set. Confirm whether
        mini-batching is intended.
        """
        self.model.train()

        inputs = torch.from_numpy(self.x_train).float()
        targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.num_steps_trained += 1

        # update train loss — recomputed after the step, without gradients
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(inputs)
            self.train_loss = self.criterion(outputs, targets).item()

        return self.plot(), self.num_steps_trained, self.train_loss
449
+
450
    def launch(self):
        """Build the full Gradio interface, wire all events, and serve (blocking)."""
        # build the Gradio interface
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    # Main canvas: rendered matplotlib figure shown as an image.
                    self.canvas = gr.Image(
                        value=self.plot(),
                        show_download_button=False,
                        container=True,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_radio = gr.Radio(
                            ["Generate", "Upload"],
                            value="Generate",
                            label="Dataset",
                        )

                        with gr.Column():
                            function_box = gr.Textbox(
                                label="Function",
                                placeholder="function of x",
                                value=self.DEFAULT_FUNCTION,
                                interactive=True,
                            )
                            with gr.Row():
                                x_min = gr.Number(
                                    label="Min x",
                                    value=-1,
                                    interactive=True,
                                )
                                x_max = gr.Number(
                                    label="Max x",
                                    value=1,
                                    interactive=True,
                                )
                            with gr.Row():
                                noise_value = gr.Number(
                                    label="Gaussian noise standard deviation",
                                    value=0,
                                    interactive=True,
                                )
                                num_points_slider = gr.Slider(
                                    label="Number of data points",
                                    minimum=0,
                                    maximum=100,
                                    step=1,
                                    value=30,
                                    interactive=True,
                                )

                            regenerate_button = gr.Button("Regenerate Data")

                        # upload data
                        # NOTE(review): dataset_radio and file_chooser are built
                        # but no handler toggles Generate/Upload here — confirm.
                        file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
                        self.file_chooser = file_chooser

                    with gr.Tab("Architecture"):
                        self.architecture_component = ArchitectureComponent(self.update_architecture, self.canvas)
                        self.architecture_component.build()

                    with gr.Tab("Train"):
                        optimizer_radio = gr.Radio(
                            ["SGD", "Adam"],
                            value=self.DEFAULT_OPTIMIZER,
                            label="Optimizer",
                        )

                        all_param_components = self.build_optimizer_components()
                        self.temp = all_param_components

                        batch_size_slider = gr.Slider(
                            label="Batch Size",
                            minimum=1,
                            maximum=self.x_train.shape[0],
                            step=1,
                            value=self.x_train.shape[0],
                            interactive=True,
                        )

                        # Read-only training progress displays.
                        with gr.Row():
                            train_step_counter = gr.Number(
                                label="Train steps",
                                value=0,
                                interactive=False,
                            )
                            train_loss_display = gr.Number(
                                label="Train loss",
                                value=self.train_loss,
                                interactive=False,
                            )

                        train_button = gr.Button("Train Step")
                        reset_model_button = gr.Button("Reset Model")

                    with gr.Tab("Plot"):
                        # plot show options
                        with gr.Column():
                            with gr.Row():
                                show_training_data = gr.Checkbox(label="Show training data", value=True)
                                show_true_function = gr.Checkbox(label="Show true function", value=True)
                            with gr.Row():
                                show_predictions = gr.Checkbox(label="Show mean prediction", value=True)

                        #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))

                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634

                        btn_export_data = gr.Button("Data")
                        btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")

                        btn_export_model = gr.Button('Model')
                        btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")

                        btn_export_code = gr.Button('Code')
                        btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")

                    with gr.Tab("Usage"):
                        gr.Markdown(''.join(open('usage.md', 'r').readlines()))

            # data options
            function_box.submit(
                fn=lambda function: self.update_data_options(function=function),
                inputs=function_box,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            x_min.submit(
                fn=lambda xmin: self.update_data_options(x_min=xmin),
                inputs=x_min,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            x_max.submit(
                fn=lambda xmax: self.update_data_options(x_max=xmax),
                inputs=x_max,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            num_points_slider.change(
                fn=lambda nsample: self.update_data_options(nsample=nsample),
                inputs=num_points_slider,
                # extra output: batch-size slider max tracks the dataset size
                outputs=[self.canvas, batch_size_slider, train_step_counter, train_loss_display],
            )
            noise_value.submit(
                fn=lambda sigma: self.update_data_options(sigma=sigma),
                inputs=noise_value,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            regenerate_button.click(
                fn=self._update_data_seed,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )

            # train options
            optimizer_radio.change(
                fn=self.update_optimizer,
                inputs=optimizer_radio,
                outputs=[*all_param_components, self.canvas, train_step_counter, train_loss_display],
            )
            batch_size_slider.change(
                fn=lambda batch_size: self.update_basic_train_hparams(batch_size=batch_size),
                inputs=batch_size_slider,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            train_button.click(
                fn=self.train_step,
                outputs=[self.canvas, train_step_counter, train_loss_display],
                show_progress="hidden",
            )
            reset_model_button.click(
                fn=self.reset_model,
                outputs=[self.canvas, train_step_counter, train_loss_display],
            )
            # Each hyperparameter field rebuilds the model on submit; partial
            # binds which optimizer/param the gradio-supplied value targets.
            for opt_name, params in self.param_components.items():
                for param_name, comp in params.items():
                    comp.submit(
                        fn=functools.partial(self.update_hparam, optimizer_name=opt_name, param_name=param_name),
                        inputs=[comp],
                        outputs=[self.canvas, train_step_counter, train_loss_display],
                    )

            # plot options
            show_training_data.change(
                fn=lambda show: self.update_plot_options(show_training_data=show),
                inputs=show_training_data,
                outputs=[self.canvas],
                show_progress="hidden",
            )
            show_true_function.change(
                fn=lambda show: self.update_plot_options(show_true_function=show),
                inputs=show_true_function,
                outputs=[self.canvas],
                show_progress="hidden",
            )
            show_predictions.change(
                fn=lambda show: self.update_plot_options(show_predictions=show),
                inputs=show_predictions,
                outputs=[self.canvas],
                show_progress="hidden",
            )

            # re-initialize per-session state on page load
            demo.load(self.on_load)

        demo.launch()
659
+
660
# Script entry point.
# Fix: guarded with __main__ so importing this module no longer launches a
# blocking web server as an import-time side effect.
if __name__ == "__main__":
    visualizer = MlpVisualizer(width=1200, height=900)
    visualizer.launch()
662
+