joel-woodfield commited on
Commit
9e5bc2a
·
1 Parent(s): 1d316f4

Add a button to update setting changes

Browse files
Files changed (1) hide show
  1. regularization.py +53 -14
regularization.py CHANGED
@@ -263,7 +263,8 @@ class Regularization:
263
  l1_ratio = 1 if reg_type == "l1" else 0
264
  alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
265
  else:
266
- coefs = np.random.randn(2, len(alphas)) # temporary
 
267
  coefs = coefs.T
268
 
269
  fig, ax = plt.subplots(figsize=(8, 8))
@@ -287,6 +288,9 @@ class Regularization:
287
  raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
288
  return loss_type
289
 
 
 
 
290
  def update_regularizer(self, reg_type: str):
291
  if reg_type not in self.REGULARIZER_TYPES:
292
  raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
@@ -408,19 +412,20 @@ class Regularization:
408
  visible=True,
409
  )
410
 
411
- with gr.Row():
412
- regularizer_type_selection = gr.Dropdown(
413
- choices=['l1', 'l2'],
414
- label='Regularizer type',
415
- value='l2',
416
- visible=True,
417
- )
418
-
419
- reg_textbox = gr.Textbox(
420
- label="Regularizer levels",
421
- value="10, 20, 30",
422
- interactive=True,
423
- )
 
424
 
425
  with gr.Row():
426
  w1_textbox = gr.Textbox(
@@ -444,8 +449,11 @@ class Regularization:
444
  label="Resolution (#points)",
445
  )
446
 
 
 
447
  with gr.Row():
448
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
 
449
  with gr.Tab("Data"):
450
  dataset_view = DatasetView()
451
  dataset_view.build(state=dataset)
@@ -496,6 +504,10 @@ class Regularization:
496
  fn=self.update_loss_type,
497
  inputs=[loss_type_selection],
498
  outputs=[loss_type],
 
 
 
 
499
  ).then(
500
  fn=self.compute_and_plot_loss_and_reg,
501
  inputs=[
@@ -611,6 +623,33 @@ class Regularization:
611
  outputs=self.loss_and_regularization_plot,
612
  )
613
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
614
  resolution_slider.change(
615
  self.update_resolution,
616
  inputs=[resolution_slider],
 
263
  l1_ratio = 1 if reg_type == "l1" else 0
264
  alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
265
  else:
266
+ return Image.new("RGB", (800, 800), color="white")
267
+
268
  coefs = coefs.T
269
 
270
  fig, ax = plt.subplots(figsize=(8, 8))
 
288
  raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
289
  return loss_type
290
 
291
+ def update_reg_path_visibility(self, loss_type: str):
292
+ return gr.update(visible=(loss_type == "l2"))
293
+
294
  def update_regularizer(self, reg_type: str):
295
  if reg_type not in self.REGULARIZER_TYPES:
296
  raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
 
412
  visible=True,
413
  )
414
 
415
+ with gr.Group():
416
+ with gr.Row():
417
+ regularizer_type_selection = gr.Dropdown(
418
+ choices=['l1', 'l2'],
419
+ label='Regularizer type',
420
+ value='l2',
421
+ visible=True,
422
+ )
423
+
424
+ reg_textbox = gr.Textbox(
425
+ label="Regularizer levels",
426
+ value="10, 20, 30",
427
+ interactive=True,
428
+ )
429
 
430
  with gr.Row():
431
  w1_textbox = gr.Textbox(
 
449
  label="Resolution (#points)",
450
  )
451
 
452
+ submit_button = gr.Button("Submit changes")
453
+
454
  with gr.Row():
455
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
456
+
457
  with gr.Tab("Data"):
458
  dataset_view = DatasetView()
459
  dataset_view.build(state=dataset)
 
504
  fn=self.update_loss_type,
505
  inputs=[loss_type_selection],
506
  outputs=[loss_type],
507
+ ).then(
508
+ fn=self.update_reg_path_visibility,
509
+ inputs=[loss_type_selection],
510
+ outputs=[path_checkbox],
511
  ).then(
512
  fn=self.compute_and_plot_loss_and_reg,
513
  inputs=[
 
623
  outputs=self.loss_and_regularization_plot,
624
  )
625
 
626
+ submit_button.click(
627
+ self.update_w1_range,
628
+ inputs=[w1_textbox],
629
+ outputs=[w1_range],
630
+ ).then(
631
+ self.update_w2_range,
632
+ inputs=[w2_textbox],
633
+ outputs=[w2_range],
634
+ ).then(
635
+ self.update_reg_levels,
636
+ inputs=[reg_textbox],
637
+ outputs=[reg_levels],
638
+ ).then(
639
+ fn=self.compute_and_plot_loss_and_reg,
640
+ inputs=[
641
+ dataset,
642
+ loss_type,
643
+ reg_type,
644
+ reg_levels,
645
+ w1_range,
646
+ w2_range,
647
+ num_dots,
648
+ plot_regularization_path,
649
+ ],
650
+ outputs=self.loss_and_regularization_plot,
651
+ )
652
+
653
  resolution_slider.change(
654
  self.update_resolution,
655
  inputs=[resolution_slider],