joel-woodfield commited on
Commit
88c5e8c
·
1 Parent(s): 2c1b454

Refactor code to use gr.State to prevent race conditions with multiple users

Browse files
Files changed (1) hide show
  1. regularization.py +453 -155
regularization.py CHANGED
@@ -30,10 +30,9 @@ logging.basicConfig(
30
  )
31
  logger = logging.getLogger("ELVIS")
32
 
 
33
  def min_corresponding_entries(W1, W2, w1, tol=0.1):
34
- #mask = np.isclose(W1, w1, atol=tol, rtol=0)
35
  mask = (W1 <= w1)
36
- #print(W1.max(), W1.min(), w1)
37
 
38
  values = W2[mask]
39
 
@@ -80,6 +79,19 @@ def l2_loss_regularization_path(y, X, regularization_type):
80
 
81
 
82
  class Regularization:
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  def __init__(self, width, height):
84
  # initialized in draw_plot
85
  #self.canvas_width = -1
@@ -89,64 +101,120 @@ class Regularization:
89
  self.canvas_height = height
90
 
91
  self.css ="""
92
- #my-button {
93
- height: 30px;
94
- font-size: 16px;
95
- }
96
-
97
- #rowheight {
98
- height: 90px;
99
- }
100
-
101
  .hidden-button {
102
  display: none;
103
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- .report-table {
106
- border: 0 !important;
107
- }
108
- .report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
109
- border: 0 !important;
110
- padding: 6px 12px;
111
- text-align: center;
112
- }"""
113
-
114
- # Regularization strengths
115
- self.alphas = [0.01, 0.1, 1, 10, 100]
116
-
117
- self.losses = {
118
- #'l1': mean_absolute_error, # slow
119
- #'l1': lambda y, pred: np.mean(abs(y - pred)),
120
- # 'l1': l1_loss,
121
- #'l2': mean_squared_error, # slow
122
- #'l2': lambda y, pred: np.mean((y - pred)**2)
123
- # 'l2': l2_loss
124
- "l1": l1_loss,
125
- "l2": l2_loss,
126
- }
127
-
128
- self.regularizers = {
129
- # 'l1': lambda w: sum(abs(w)),
130
- # 'l2': np.linalg.norm
131
- "l1": l1_norm,
132
- "l2": l2_norm,
133
- }
134
-
135
- #self.Model = Ridge #l2 loss + l2 reg
136
- #self.Model = Lasso #l2 loss + l1 reg
137
-
138
- self.loss_type = 'l2'
139
- self.reg_type = 'l2'
140
-
141
- self.loss = self.losses[self.loss_type]
142
- self.regularizer = self.regularizers[self.reg_type]
143
-
144
- self.reg_levels = [10, 20, 30]
145
- self.w1_range = (-100, 100)
146
- self.w2_range = (-100, 100)
147
- self.num_dots = 500
148
-
149
- self.plot_regularization_path = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  def plot_regularization_contour(self):
152
  '''
@@ -264,25 +332,48 @@ class Regularization:
264
 
265
  return img
266
 
267
- def plot_data(self):
268
- # make sure the data is the same as the one used in plot_regularization_contour
269
- _, _, coef = make_regression(n_samples=200, n_features=2, noise=15, random_state=0, coef=True)
 
 
270
 
271
- x1 = np.linspace(-1, 1, 50)
272
- x2 = np.linspace(-1, 1, 50)
273
  mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
274
  X = np.stack((mesh_x1.ravel(), mesh_x2.ravel()), axis=-1)
275
- y = X @ coef
276
 
277
  z = y.reshape(mesh_x1.shape)
278
 
279
- fig = go.Figure(data=go.Surface(
280
- z=z,
281
- x=mesh_x1,
282
- y=mesh_x2,
283
- colorscale='Viridis',
284
- opacity=0.8
285
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  fig.update_layout(
288
  title="Data",
@@ -297,12 +388,10 @@ class Regularization:
297
  )
298
  return fig
299
 
300
- def plot_strength_vs_weight(self):
301
- # make sure the data is the same as the one used in plot_regularization_contour
302
- X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
303
  alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
304
- if self.loss_type == "l2":
305
- l1_ratio = 1 if self.reg_type == "l1" else 0
306
  alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
307
  else:
308
  coefs = np.random.randn(2, len(alphas)) # temporary
@@ -324,42 +413,68 @@ class Regularization:
324
 
325
  return img
326
 
327
- def update_loss(self, loss_type):
328
- self.loss_type = loss_type
329
- self.loss = self.losses[loss_type]
330
- return self.plot_regularization_contour(), self.plot_strength_vs_weight()
331
 
332
- def update_regularizer(self, reg_type):
333
- self.reg_type = reg_type
334
- self.regularizer = self.regularizers[reg_type]
 
335
 
336
- return self.plot_regularization_contour(), self.plot_strength_vs_weight()
 
 
337
 
338
- def update_reg_levels(self, reg_levels):
339
- self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
 
340
 
341
- return self.plot_regularization_contour()
 
 
342
 
343
- def update_w1_range(self, w1_range):
344
- self.w1_range = [float(w1) for w1 in w1_range.split(",")]
345
- logger.info("Updated w1 range to " + str(self.w1_range))
346
 
347
- return self.plot_regularization_contour()
 
348
 
349
- def update_w2_range(self, w2_range):
350
- self.w2_range = [float(w2) for w2 in w2_range.split(",")]
351
- logger.info("Updated w2 range to " + str(self.w2_range))
352
-
353
- return self.plot_regularization_contour()
 
 
 
 
 
354
 
355
- def update_resolution(self, num_dots):
356
- self.num_dots = num_dots
357
- logger.info("updated resolution to " + str(num_dots))
358
- return self.plot_regularization_contour()
359
 
360
- def update_plot_path(self, plot_path):
361
- self.plot_regularization_path = plot_path
362
- return self.plot_regularization_contour()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  def launch(self):
365
  # build the Gradio interface
@@ -367,57 +482,102 @@ class Regularization:
367
  # app title
368
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  # GUI elements and layout
371
  with gr.Row():
372
  with gr.Column(scale=2):
373
- with gr.Tab("Regularization contour"):
374
- self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  with gr.Tab("Data"):
376
- self.data_3d_plot = gr.Plot(value=self.plot_data(), container=True)
 
 
377
  with gr.Tab("Strength vs weight"):
378
- self.strength_vs_weight = gr.Image(value=self.plot_strength_vs_weight(), container=True)
 
 
 
 
 
379
 
380
  with gr.Column(scale=1):
381
  with gr.Tab("Settings"):
382
- dataset_radio = gr.Radio(["make_regression", "Upload"],
383
- value="make_regression", label="Dataset type", elem_id="rowheight")
384
-
385
- # upload data
386
- file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
387
- self.file_chooser = file_chooser
388
-
389
- # loss type
390
- loss_type = gr.Dropdown(choices=['l1', 'l2'],
391
- label='Loss type',
392
- value='l2',
393
- visible=True)
394
 
395
  with gr.Row():
396
- # regularizer type
397
- regularizer_type = gr.Dropdown(choices=['l1', 'l2'],
398
- label='Regularizer type',
399
- value='l2',
400
- visible=True)
401
-
402
- # regularization strength
403
- #reg_textbox = gr.Textbox(label="Regularization constants")
404
- reg_textbox = gr.Textbox(label="Regularizer levels",
405
- value="10, 20, 30",
406
- interactive=True)
 
407
  self.reg_textbox = reg_textbox
408
 
409
  with gr.Row():
410
- # parameter value ranges
411
- w1_textbox = gr.Textbox(label="w1 range",
412
- value="-100, 100",
413
- interactive=True)
414
-
415
- w2_textbox = gr.Textbox(label="w2 range",
416
- value="-100, 100",
417
- interactive=True)
418
-
419
- # resolution
420
- slider = gr.Slider(minimum=100, maximum=1000, value=500, step=1, label="Resolution (#points)")
 
 
 
 
 
 
 
 
421
 
422
  # plot path
423
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
@@ -438,34 +598,172 @@ class Regularization:
438
 
439
  with gr.Tab("Usage"):
440
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
441
-
442
 
443
  # event handlers for GUI elements
444
- loss_type.change(
445
- fn=self.update_loss,
446
- inputs=loss_type,
447
- outputs=(self.regularization_contour, self.strength_vs_weight),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  )
449
 
450
- regularizer_type.change(
451
  fn=self.update_regularizer,
452
- inputs=regularizer_type,
453
- outputs=(self.regularization_contour, self.strength_vs_weight),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  )
455
 
456
- reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
457
- outputs=self.regularization_contour)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
- w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
460
- outputs=self.regularization_contour)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
- w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
463
- outputs=self.regularization_contour)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
- slider.change(self.update_resolution, inputs=slider, outputs=self.regularization_contour)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
467
  path_checkbox.change(
468
- self.update_plot_path, inputs=path_checkbox, outputs=self.regularization_contour
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  )
470
 
471
  demo.launch()
 
30
  )
31
  logger = logging.getLogger("ELVIS")
32
 
33
+
34
  def min_corresponding_entries(W1, W2, w1, tol=0.1):
 
35
  mask = (W1 <= w1)
 
36
 
37
  values = W2[mask]
38
 
 
79
 
80
 
81
  class Regularization:
82
+ LOSS_TYPES = ['l1', 'l2']
83
+ REGULARIZER_TYPES = ['l1', 'l2']
84
+
85
+ LOSS_FUNCTIONS = {
86
+ 'l1': l1_loss,
87
+ 'l2': l2_loss,
88
+ }
89
+
90
+ REGULARIZER_FUNCTIONS = {
91
+ 'l1': l1_norm,
92
+ 'l2': l2_norm,
93
+ }
94
+
95
  def __init__(self, width, height):
96
  # initialized in draw_plot
97
  #self.canvas_width = -1
 
101
  self.canvas_height = height
102
 
103
  self.css ="""
 
 
 
 
 
 
 
 
 
104
  .hidden-button {
105
  display: none;
106
  }
107
+ """
108
+
109
+ def compute_and_plot_loss_and_reg(
110
+ self,
111
+ X: np.ndarray,
112
+ y: np.ndarray,
113
+ loss_type: str,
114
+ reg_type: str,
115
+ reg_levels: list,
116
+ w1_range: list,
117
+ w2_range: list,
118
+ num_dots: int,
119
+ plot_path: bool,
120
+ ):
121
+ W1, W2 = self._build_parameter_grid(
122
+ w1_range, w2_range, num_dots
123
+ )
124
 
125
+ losses = self._compute_losses(
126
+ X, y, loss_type, W1, W2
127
+ )
128
+
129
+ reg_values = self._compute_reg_values(
130
+ W1, W2, reg_type
131
+ )
132
+
133
+ loss_levels = [
134
+ min_corresponding_entries(
135
+ reg_values, losses, reg_level
136
+ )
137
+ for reg_level in reg_levels
138
+ ]
139
+ loss_levels.reverse()
140
+
141
+ if plot_path:
142
+ if loss_type == "l2":
143
+ path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
144
+ else:
145
+ min_loss_reg = reg_values.ravel()[np.argmin(losses)]
146
+ path_reg_levels = np.linspace(0, min_loss_reg, 20)
147
+ path_w = []
148
+ for reg_level in path_reg_levels:
149
+ mask = reg_values <= reg_level
150
+ if np.sum(mask) == 0:
151
+ continue
152
+ idx = np.argmin(losses[mask])
153
+ path_w.append(
154
+ np.stack((W1, W2), axis=-1)[mask][idx]
155
+ )
156
+
157
+ path_w = np.array(path_w)
158
+ else:
159
+ path_w = None
160
+
161
+ return self.plot_loss_and_reg(
162
+ W1,
163
+ W2,
164
+ losses,
165
+ reg_values,
166
+ loss_levels,
167
+ reg_levels,
168
+ path_w,
169
+ )
170
+
171
+ def plot_loss_and_reg(
172
+ self,
173
+ W1: np.ndarray,
174
+ W2: np.ndarray,
175
+ losses: np.ndarray,
176
+ reg_values: np.ndarray,
177
+ loss_levels: list,
178
+ reg_levels: list,
179
+ path_w: np.ndarray | None,
180
+ ):
181
+ fig, ax = plt.subplots(figsize=(8, 8))
182
+ ax.set_title("")
183
+ ax.set_xlabel("w1")
184
+ ax.set_ylabel("w2")
185
+
186
+ cmap = plt.get_cmap("viridis")
187
+ N = len(reg_levels)
188
+ colors = [cmap(i / (N - 1)) for i in range(N)]
189
+
190
+ # regularizer contours
191
+ cs1 = ax.contour(W1, W2, reg_values, levels=reg_levels, colors=colors, linestyles="dashed")
192
+ ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
193
+
194
+ # loss contours
195
+ cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
196
+ ax.clabel(cs2, inline=True, fontsize=8)
197
+
198
+ # regularization path
199
+ if path_w is not None:
200
+ ax.plot(path_w[:, 0], path_w[:, 1], "r-")
201
+
202
+ # legend
203
+ loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
204
+ reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
205
+ handles = [loss_line, reg_line]
206
+ if path_w is not None:
207
+ path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
208
+ handles.append(path_line)
209
+ ax.legend(handles=handles)
210
+
211
+ buf = io.BytesIO()
212
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
213
+ plt.close(fig)
214
+ buf.seek(0)
215
+ img = Image.open(buf)
216
+
217
+ return img
218
 
219
  def plot_regularization_contour(self):
220
  '''
 
332
 
333
  return img
334
 
335
+ def plot_data(self, X_data: np.ndarray, y_data: np.ndarray, coefs: np.ndarray):
336
+ x1_min = X_data[:, 0].min() - 1
337
+ x1_max = X_data[:, 0].max() + 1
338
+ x2_min = X_data[:, 1].min() - 1
339
+ x2_max = X_data[:, 1].max() + 1
340
 
341
+ x1 = np.linspace(x1_min, x1_max, 100)
342
+ x2 = np.linspace(x2_min, x2_max, 100)
343
  mesh_x1, mesh_x2 = np.meshgrid(x1, x2)
344
  X = np.stack((mesh_x1.ravel(), mesh_x2.ravel()), axis=-1)
345
+ y = X @ coefs
346
 
347
  z = y.reshape(mesh_x1.shape)
348
 
349
+ fig = go.Figure()
350
+
351
+ fig.add_trace(
352
+ go.Surface(
353
+ z=z,
354
+ x=mesh_x1,
355
+ y=mesh_x2,
356
+ colorscale='Viridis',
357
+ opacity=0.8,
358
+ name='True function',
359
+ )
360
+ )
361
+
362
+ fig.add_trace(
363
+ go.Scatter3d(
364
+ x=X_data[:, 0],
365
+ y=X_data[:, 1],
366
+ z=y_data,
367
+ mode='markers',
368
+ marker=dict(
369
+ size=3,
370
+ color='red',
371
+ opacity=0.8,
372
+ symbol='circle',
373
+ ),
374
+ name='Data Points',
375
+ )
376
+ )
377
 
378
  fig.update_layout(
379
  title="Data",
 
388
  )
389
  return fig
390
 
391
+ def plot_strength_vs_weight(self, X: np.ndarray, y: np.ndarray, loss_type: str, reg_type: str):
 
 
392
  alphas = np.concat([np.zeros(1), np.logspace(-2, 2, 100)])
393
+ if loss_type == "l2":
394
+ l1_ratio = 1 if reg_type == "l1" else 0
395
  alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
396
  else:
397
  coefs = np.random.randn(2, len(alphas)) # temporary
 
413
 
414
  return img
415
 
416
+ def update_loss_type(self, loss_type: str):
417
+ if loss_type not in self.LOSS_TYPES:
418
+ raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
419
+ return loss_type
420
 
421
+ def update_regularizer(self, reg_type: str):
422
+ if reg_type not in self.REGULARIZER_TYPES:
423
+ raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
424
+ return reg_type
425
 
426
+ def update_reg_levels(self, reg_levels_input: str):
427
+ reg_levels = [float(reg_level) for reg_level in reg_levels_input.split(",")]
428
+ return reg_levels
429
 
430
+ def update_w1_range(self, w1_range_input: str):
431
+ w1_range = [float(w1) for w1 in w1_range_input.split(",")]
432
+ return w1_range
433
 
434
+ def update_w2_range(self, w2_range_input: str):
435
+ w2_range = [float(w2) for w2 in w2_range_input.split(",")]
436
+ return w2_range
437
 
438
+ def update_resolution(self, num_dots: int):
439
+ return num_dots
 
440
 
441
+ def update_plot_path(self, plot_path: bool):
442
+ return plot_path
443
 
444
+ def _build_parameter_grid(
445
+ self,
446
+ w1_range: list,
447
+ w2_range: list,
448
+ num_dots: int,
449
+ ) -> tuple[np.ndarray, np.ndarray]:
450
+ # build grid in parameter space
451
+ w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
452
+ w2 = np.linspace(w2_range[0], w2_range[1], num_dots)
453
+ W1, W2 = np.meshgrid(w1, w2)
454
 
455
+ return W1, W2
 
 
 
456
 
457
+ def _compute_losses(
458
+ self,
459
+ X: np.ndarray,
460
+ y: np.ndarray,
461
+ loss_type: str,
462
+ W1: np.ndarray,
463
+ W2: np.ndarray,
464
+ ) -> np.ndarray:
465
+ stacked = np.stack((W1, W2), axis=-1)
466
+ losses = self.LOSS_FUNCTIONS[loss_type](stacked, y, X)
467
+ return losses
468
+
469
+ def _compute_reg_values(
470
+ self,
471
+ W1: np.ndarray,
472
+ W2: np.ndarray,
473
+ reg_type: str,
474
+ ) -> np.ndarray:
475
+ stacked = np.stack((W1, W2), axis=-1)
476
+ regs = self.REGULARIZER_FUNCTIONS[reg_type](stacked)
477
+ return regs
478
 
479
  def launch(self):
480
  # build the Gradio interface
 
482
  # app title
483
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
484
 
485
+ # states
486
+ loss_type = gr.State("l2")
487
+ reg_type = gr.State("l2")
488
+ reg_levels = gr.State([10, 20, 30])
489
+ w1_range = gr.State([-100, 100])
490
+ w2_range = gr.State([-100, 100])
491
+ num_dots = gr.State(500)
492
+ plot_regularization_path = gr.State(False)
493
+
494
+ X, y, coefs = make_regression(
495
+ n_samples=200, n_features=2, noise=15, random_state=0, coef=True
496
+ )
497
+ X = gr.State(X)
498
+ y = gr.State(y)
499
+ coefs = gr.State(coefs)
500
+
501
  # GUI elements and layout
502
  with gr.Row():
503
  with gr.Column(scale=2):
504
+ with gr.Tab("Loss and Regularization"):
505
+ self.loss_and_regularization_plot = gr.Image(
506
+ value=self.compute_and_plot_loss_and_reg(
507
+ X.value,
508
+ y.value,
509
+ loss_type.value,
510
+ reg_type.value,
511
+ reg_levels.value,
512
+ w1_range.value,
513
+ w2_range.value,
514
+ num_dots.value,
515
+ plot_regularization_path.value,
516
+ ),
517
+ container=True,
518
+ )
519
  with gr.Tab("Data"):
520
+ self.data_3d_plot = gr.Plot(
521
+ value=self.plot_data(X.value, y.value, coefs.value), container=True
522
+ )
523
  with gr.Tab("Strength vs weight"):
524
+ self.strength_vs_weight = gr.Image(
525
+ value=self.plot_strength_vs_weight(
526
+ X.value, y.value, loss_type.value, reg_type.value
527
+ ),
528
+ container=True,
529
+ )
530
 
531
  with gr.Column(scale=1):
532
  with gr.Tab("Settings"):
533
+ dataset_radio = gr.Radio(
534
+ ["make_regression", "Upload"],
535
+ value="make_regression",
536
+ label="Dataset type",
537
+ )
538
+
539
+ loss_type_selection = gr.Dropdown(
540
+ choices=['l1', 'l2'],
541
+ label='Loss type',
542
+ value='l2',
543
+ visible=True,
544
+ )
545
 
546
  with gr.Row():
547
+ regularizer_type_selection = gr.Dropdown(
548
+ choices=['l1', 'l2'],
549
+ label='Regularizer type',
550
+ value='l2',
551
+ visible=True,
552
+ )
553
+
554
+ reg_textbox = gr.Textbox(
555
+ label="Regularizer levels",
556
+ value="10, 20, 30",
557
+ interactive=True,
558
+ )
559
  self.reg_textbox = reg_textbox
560
 
561
  with gr.Row():
562
+ w1_textbox = gr.Textbox(
563
+ label="w1 range",
564
+ value="-100, 100",
565
+ interactive=True,
566
+ )
567
+
568
+ w2_textbox = gr.Textbox(
569
+ label="w2 range",
570
+ value="-100, 100",
571
+ interactive=True,
572
+ )
573
+
574
+ resolution_slider = gr.Slider(
575
+ minimum=100,
576
+ maximum=1000,
577
+ value=500,
578
+ step=1,
579
+ label="Resolution (#points)",
580
+ )
581
 
582
  # plot path
583
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
 
598
 
599
  with gr.Tab("Usage"):
600
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
 
601
 
602
  # event handlers for GUI elements
603
+ loss_type_selection.change(
604
+ fn=self.update_loss_type,
605
+ inputs=[loss_type_selection],
606
+ outputs=[loss_type],
607
+ ).then(
608
+ fn=self.compute_and_plot_loss_and_reg,
609
+ inputs=[
610
+ X,
611
+ y,
612
+ loss_type,
613
+ reg_type,
614
+ reg_levels,
615
+ w1_range,
616
+ w2_range,
617
+ num_dots,
618
+ plot_regularization_path,
619
+ ],
620
+ outputs=self.loss_and_regularization_plot,
621
+ ).then(
622
+ fn=self.plot_strength_vs_weight,
623
+ inputs=[
624
+ X,
625
+ y,
626
+ loss_type,
627
+ reg_type,
628
+ ],
629
+ outputs=self.strength_vs_weight,
630
  )
631
 
632
+ regularizer_type_selection.change(
633
  fn=self.update_regularizer,
634
+ inputs=[regularizer_type_selection],
635
+ outputs=[reg_type],
636
+ ).then(
637
+ fn=self.compute_and_plot_loss_and_reg,
638
+ inputs=[
639
+ X,
640
+ y,
641
+ loss_type,
642
+ reg_type,
643
+ reg_levels,
644
+ w1_range,
645
+ w2_range,
646
+ num_dots,
647
+ plot_regularization_path,
648
+ ],
649
+ outputs=self.loss_and_regularization_plot,
650
+ ).then(
651
+ fn=self.plot_strength_vs_weight,
652
+ inputs=[
653
+ X,
654
+ y,
655
+ loss_type,
656
+ reg_type,
657
+ ],
658
+ outputs=self.strength_vs_weight,
659
  )
660
 
661
+ reg_textbox.submit(
662
+ self.update_reg_levels,
663
+ inputs=[reg_textbox],
664
+ outputs=[reg_levels],
665
+ ).then(
666
+ fn=self.compute_and_plot_loss_and_reg,
667
+ inputs=[
668
+ X,
669
+ y,
670
+ loss_type,
671
+ reg_type,
672
+ reg_levels,
673
+ w1_range,
674
+ w2_range,
675
+ num_dots,
676
+ plot_regularization_path,
677
+ ],
678
+ outputs=self.loss_and_regularization_plot,
679
+ ).then(
680
+ fn=self.plot_strength_vs_weight,
681
+ inputs=[
682
+ X,
683
+ y,
684
+ loss_type,
685
+ reg_type,
686
+ ],
687
+ outputs=self.strength_vs_weight,
688
+ )
689
 
690
+ w1_textbox.submit(
691
+ self.update_w1_range,
692
+ inputs=[w1_textbox],
693
+ outputs=[w1_range],
694
+ ).then(
695
+ fn=self.compute_and_plot_loss_and_reg,
696
+ inputs=[
697
+ X,
698
+ y,
699
+ loss_type,
700
+ reg_type,
701
+ reg_levels,
702
+ w1_range,
703
+ w2_range,
704
+ num_dots,
705
+ plot_regularization_path,
706
+ ],
707
+ outputs=self.loss_and_regularization_plot,
708
+ )
709
 
710
+ w2_textbox.submit(
711
+ self.update_w2_range,
712
+ inputs=[w2_textbox],
713
+ outputs=[w2_range],
714
+ ).then(
715
+ fn=self.compute_and_plot_loss_and_reg,
716
+ inputs=[
717
+ X,
718
+ y,
719
+ loss_type,
720
+ reg_type,
721
+ reg_levels,
722
+ w1_range,
723
+ w2_range,
724
+ num_dots,
725
+ ],
726
+ outputs=self.loss_and_regularization_plot,
727
+ )
728
 
729
+ resolution_slider.change(
730
+ self.update_resolution,
731
+ inputs=[resolution_slider],
732
+ outputs=[num_dots],
733
+ ).then(
734
+ fn=self.compute_and_plot_loss_and_reg,
735
+ inputs=[
736
+ X,
737
+ y,
738
+ loss_type,
739
+ reg_type,
740
+ reg_levels,
741
+ w1_range,
742
+ w2_range,
743
+ num_dots,
744
+ plot_regularization_path,
745
+ ],
746
+ outputs=self.loss_and_regularization_plot,
747
+ )
748
 
749
  path_checkbox.change(
750
+ self.update_plot_path,
751
+ inputs=[path_checkbox],
752
+ outputs=[plot_regularization_path],
753
+ ).then(
754
+ fn=self.compute_and_plot_loss_and_reg,
755
+ inputs=[
756
+ X,
757
+ y,
758
+ loss_type,
759
+ reg_type,
760
+ reg_levels,
761
+ w1_range,
762
+ w2_range,
763
+ num_dots,
764
+ plot_regularization_path,
765
+ ],
766
+ outputs=self.loss_and_regularization_plot,
767
  )
768
 
769
  demo.launch()