joel-woodfield commited on
Commit
3892dda
·
1 Parent(s): 7b0039e

Initial commit

Browse files
Files changed (4) hide show
  1. README.md +6 -6
  2. mlp_visualizer.py +374 -0
  3. requirements.txt +8 -0
  4. usage.md +7 -0
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: Mlp Visualizer
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
1
  ---
2
+ title: MLP Visualizer
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 5.46.0
8
+ app_file: mlp_visualizer.py
9
  pinned: false
10
  ---
11
 
mlp_visualizer.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from pathlib import Path
3
+ import pickle
4
+
5
+ import gradio as gr
6
+ import inspect
7
+ import io
8
+ from jinja2 import Template
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.lines as mlines
11
+ import numpy as np
12
+ import numexpr
13
+ import pandas as pd
14
+ from PIL import Image
15
+ import plotly.graph_objects as go
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ import traceback
20
+ import yaml
21
+
22
+ import logging
23
+ logging.basicConfig(
24
+ level=logging.INFO, # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
25
+ format="%(asctime)s [%(levelname)s] %(message)s", # log format
26
+ )
27
+ logger = logging.getLogger("ELVIS")
28
+
29
+
30
+ NUMEXPR_CONSTANTS = {
31
+ 'pi': np.pi,
32
+ 'PI': np.pi,
33
+ 'e': np.e,
34
+ }
35
+
36
+
37
+ def get_function(function, xlim=(-1, 1), nsample=100):
38
+ x = np.linspace(xlim[0], xlim[1], nsample)
39
+ y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS})
40
+ x = x.reshape(-1, 1)
41
+ return x, y
42
+
43
+
44
+ def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
45
+ num_points_to_generate = 100
46
+ if nsample > num_points_to_generate:
47
+ raise ValueError(f"nsample too large, limit to {num_points_to_generate}")
48
+
49
+ rng = np.random.default_rng(seed)
50
+ x = rng.uniform(xlim[0], xlim[1], size=num_points_to_generate)
51
+ x = x[:nsample]
52
+ x = np.sort(x)
53
+
54
+ rng = np.random.default_rng(seed)
55
+ noise = sigma * rng.standard_normal(nsample)
56
+ y = numexpr.evaluate(function, local_dict={'x': x, **NUMEXPR_CONSTANTS}) + noise
57
+
58
+ x = x.reshape(-1, 1)
59
+ return x, y
60
+
61
+
62
+ class MlpVisualizer:
63
+ DEFAULT_FUNCTION = "sin(2 * pi * x)"
64
+
65
+ def _init_state(self):
66
+ self.data_options = {
67
+ "function": self.DEFAULT_FUNCTION,
68
+ "nsample": 30,
69
+ "sigma": 0,
70
+ "seed": 0,
71
+ "x_min": -1,
72
+ "x_max": 1,
73
+ }
74
+
75
+ self.x_train, self.y_train = self.generate_data()
76
+ self.model, self.optimizer = self.init_model()
77
+ self.criterion = nn.MSELoss()
78
+
79
+ self.plot_options = {
80
+ "show_training_data": True,
81
+ "show_true_function": True,
82
+ "show_predictions": True,
83
+ }
84
+
85
+ def __init__(self, width, height):
86
+ self.canvas_width = width
87
+ self.canvas_height = height
88
+
89
+ self._init_state()
90
+
91
+ self.plot_cmap = plt.get_cmap("tab20")
92
+
93
+ self.css = """
94
+ .hidden-button {
95
+ display: none;
96
+ }"""
97
+
98
+ def on_load(self):
99
+ self._init_state()
100
+
101
+ def generate_data(self):
102
+ function = self.data_options["function"]
103
+ nsample = self.data_options["nsample"]
104
+ sigma = self.data_options["sigma"]
105
+ x_min = self.data_options["x_min"]
106
+ x_max = self.data_options["x_max"]
107
+
108
+ return get_data_points(function, xlim=(x_min, x_max), nsample=nsample, sigma=sigma, seed=self.data_options["seed"])
109
+
110
+ def init_model(self):
111
+ model = nn.Sequential(
112
+ nn.Linear(1, 64),
113
+ nn.ReLU(),
114
+ nn.Linear(64, 64),
115
+ nn.ReLU(),
116
+ nn.Linear(64, 1),
117
+ )
118
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
119
+ return model, optimizer
120
+
121
+ def plot(self):
122
+ '''
123
+ '''
124
+
125
+ logger.info("Initializing figure")
126
+ fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=100)
127
+ # set entire figure to be the canvas to allow simple conversion of mouse
128
+ # position to coordinates in the figure
129
+ ax = fig.add_axes([0., 0., 1., 1.]) #
130
+ ax.margins(x=0, y=0) # no padding in both directions
131
+
132
+ x_test, y_test = get_function(self.data_options["function"], xlim=(-2, 2), nsample=100)
133
+ y_pred = self.model(torch.from_numpy(x_test).float()).detach().numpy()
134
+
135
+ # plot
136
+ fig, ax = plt.subplots(figsize=(8, 8))
137
+ ax.set_title("")
138
+ ax.set_xlabel("x")
139
+ ax.set_ylabel("y")
140
+
141
+ if self.plot_options["show_training_data"]:
142
+ plt.scatter(self.x_train.flatten(), self.y_train, label='training data', color=self.plot_cmap(0))
143
+
144
+ if self.plot_options["show_true_function"]:
145
+ plt.plot(x_test.flatten(), y_test, label='true function', color=self.plot_cmap(1))
146
+
147
+ if self.plot_options["show_predictions"]:
148
+ plt.plot(x_test.flatten(), y_pred, linestyle="--", label='prediction', color=self.plot_cmap(2))
149
+
150
+ plt.legend()
151
+
152
+ buf = io.BytesIO()
153
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
154
+ plt.close(fig)
155
+ buf.seek(0)
156
+ img = Image.open(buf)
157
+
158
+ return img
159
+
160
+ def _update_data_seed(self):
161
+ self.data_options["seed"] += 1
162
+ self.x_train, self.y_train = self.generate_data()
163
+ self.reset_model()
164
+ return self.plot()
165
+
166
+ def reset_model(self):
167
+ self.model, self.optimizer = self.init_model()
168
+ return self.plot()
169
+
170
+ def update_data_options(self, **kwargs):
171
+ for key, value in kwargs.items():
172
+ if key in self.data_options:
173
+
174
+ # if function - test if valid
175
+ if key == "function":
176
+ try:
177
+ x = np.linspace(-1, 1, 10)
178
+ y = numexpr.evaluate(value, local_dict={'x': x, **NUMEXPR_CONSTANTS})
179
+ except Exception as e:
180
+ raise ValueError(f"Invalid function: {e}")
181
+
182
+ self.data_options[key] = value
183
+
184
+ # reset data and model
185
+ self.x_train, self.y_train = self.generate_data()
186
+ self.reset_model()
187
+
188
+ return self.plot()
189
+
190
+ def update_plot_options(self, **kwargs):
191
+ for key, value in kwargs.items():
192
+ if key in self.plot_options:
193
+ self.plot_options[key] = value
194
+ return self.plot()
195
+
196
+ def train_step(self):
197
+ self.model.train()
198
+
199
+ inputs = torch.from_numpy(self.x_train).float()
200
+ targets = torch.from_numpy(self.y_train).float().unsqueeze(1)
201
+ outputs = self.model(inputs)
202
+ loss = self.criterion(outputs, targets)
203
+
204
+ self.optimizer.zero_grad()
205
+ loss.backward()
206
+ self.optimizer.step()
207
+
208
+ print(f"Training loss: {loss.item():.4f}")
209
+
210
+ return self.plot()
211
+
212
+ def launch(self):
213
+ # build the Gradio interface
214
+ with gr.Blocks(css=self.css) as demo:
215
+ # app title
216
+ gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
217
+
218
+ # GUI elements and layout
219
+ with gr.Row():
220
+ with gr.Column(scale=2):
221
+ self.canvas = gr.Image(
222
+ value=self.plot(),
223
+ show_download_button=False,
224
+ container=True
225
+ )
226
+
227
+ with gr.Column(scale=1):
228
+ with gr.Tab("Dataset"):
229
+ dataset_radio = gr.Radio(
230
+ ["Generate", "Upload"],
231
+ value="Generate",
232
+ label="Dataset",
233
+ )
234
+
235
+ with gr.Column():
236
+ function_box = gr.Textbox(
237
+ label="Function",
238
+ placeholder="function of x",
239
+ value=self.DEFAULT_FUNCTION,
240
+ interactive=True,
241
+ )
242
+ with gr.Row():
243
+ x_min = gr.Number(
244
+ label="Min x",
245
+ value=-1,
246
+ interactive=True,
247
+ )
248
+ x_max = gr.Number(
249
+ label="Max x",
250
+ value=1,
251
+ interactive=True,
252
+ )
253
+ with gr.Row():
254
+ noise_value = gr.Number(
255
+ label="Gaussian noise standard deviation",
256
+ value=0,
257
+ interactive=True,
258
+ )
259
+ num_points_slider = gr.Slider(
260
+ label="Number of data points",
261
+ minimum=0,
262
+ maximum=100,
263
+ step=1,
264
+ value=30,
265
+ interactive=True,
266
+ )
267
+
268
+ regenerate_button = gr.Button("Regenerate Data")
269
+
270
+ # upload data
271
+ file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
272
+ self.file_chooser = file_chooser
273
+
274
+ with gr.Tab("Model"):
275
+ gr.Markdown("TODO")
276
+
277
+ with gr.Tab("Train"):
278
+ train_button = gr.Button("Train Step")
279
+ reset_model_button = gr.Button("Reset Model")
280
+
281
+ with gr.Tab("Plot"):
282
+ # plot show options
283
+ with gr.Column():
284
+ with gr.Row():
285
+ show_training_data = gr.Checkbox(label="Show training data", value=True)
286
+ show_true_function = gr.Checkbox(label="Show true function", value=True)
287
+ with gr.Row():
288
+ show_predictions = gr.Checkbox(label="Show mean prediction", value=True)
289
+
290
+ #gr.Markdown(''.join(open('kernel_examples.md', 'r').readlines()))
291
+
292
+ with gr.Tab("Export"):
293
+ # use hidden download button to generate files on the fly
294
+ # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
295
+
296
+ btn_export_data = gr.Button("Data")
297
+ btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
298
+
299
+ btn_export_model = gr.Button('Model')
300
+ btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
301
+
302
+ btn_export_code = gr.Button('Code')
303
+ btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
304
+
305
+ with gr.Tab("Usage"):
306
+ gr.Markdown(''.join(open('usage.md', 'r').readlines()))
307
+
308
+ # data options
309
+ function_box.submit(
310
+ fn=lambda function: self.update_data_options(function=function),
311
+ inputs=function_box,
312
+ outputs=[self.canvas],
313
+ )
314
+ x_min.submit(
315
+ fn=lambda xmin: self.update_data_options(x_min=xmin),
316
+ inputs=x_min,
317
+ outputs=[self.canvas],
318
+ )
319
+ x_max.submit(
320
+ fn=lambda xmax: self.update_data_options(x_max=xmax),
321
+ inputs=x_max,
322
+ outputs=[self.canvas],
323
+ )
324
+ num_points_slider.change(
325
+ fn=lambda nsample: self.update_data_options(nsample=nsample),
326
+ inputs=num_points_slider,
327
+ outputs=[self.canvas],
328
+ )
329
+ noise_value.submit(
330
+ fn=lambda sigma: self.update_data_options(sigma=sigma),
331
+ inputs=noise_value,
332
+ outputs=[self.canvas],
333
+ )
334
+ regenerate_button.click(
335
+ fn=self._update_data_seed,
336
+ outputs=[self.canvas],
337
+ )
338
+
339
+ # model options
340
+
341
+ # train options
342
+ train_button.click(
343
+ fn=self.train_step,
344
+ outputs=[self.canvas],
345
+ )
346
+ reset_model_button.click(
347
+ fn=self.reset_model,
348
+ outputs=[self.canvas],
349
+ )
350
+
351
+ # plot options
352
+ show_training_data.change(
353
+ fn=lambda show: self.update_plot_options(show_training_data=show),
354
+ inputs=show_training_data,
355
+ outputs=[self.canvas],
356
+ )
357
+ show_true_function.change(
358
+ fn=lambda show: self.update_plot_options(show_true_function=show),
359
+ inputs=show_true_function,
360
+ outputs=[self.canvas],
361
+ )
362
+ show_predictions.change(
363
+ fn=lambda show: self.update_plot_options(show_predictions=show),
364
+ inputs=show_predictions,
365
+ outputs=[self.canvas],
366
+ )
367
+
368
+ demo.load(self.on_load)
369
+
370
+ demo.launch()
371
+
372
+ visualizer = MlpVisualizer(width=1200, height=900)
373
+ visualizer.launch()
374
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numexpr
3
+ numpy
4
+ pandas
5
+ pillow
6
+ plotly
7
+ scikit-learn
8
+ torch
usage.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ **Quick start**
2
+
3
+ **Kernel examples**
4
+ * RBF()
5
+ * RBF(length_scale=1, length_scale_bounds="fixed")
6
+ * RBF(length_scale=100, length_scale_bounds="fixed") + WhiteKernel()
7
+ * ConstantKernel()*DotProduct(sigma_0=0, sigma_0_bounds="fixed") + WhiteKernel()