Spaces:
Sleeping
Sleeping
Commit ·
9e5bc2a
1
Parent(s): 1d316f4
Add a button to update setting changes
Browse files- regularization.py +53 -14
regularization.py
CHANGED
|
@@ -263,7 +263,8 @@ class Regularization:
|
|
| 263 |
l1_ratio = 1 if reg_type == "l1" else 0
|
| 264 |
alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
|
| 265 |
else:
|
| 266 |
-
|
|
|
|
| 267 |
coefs = coefs.T
|
| 268 |
|
| 269 |
fig, ax = plt.subplots(figsize=(8, 8))
|
|
@@ -287,6 +288,9 @@ class Regularization:
|
|
| 287 |
raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
|
| 288 |
return loss_type
|
| 289 |
|
|
|
|
|
|
|
|
|
|
| 290 |
def update_regularizer(self, reg_type: str):
|
| 291 |
if reg_type not in self.REGULARIZER_TYPES:
|
| 292 |
raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
|
|
@@ -408,19 +412,20 @@ class Regularization:
|
|
| 408 |
visible=True,
|
| 409 |
)
|
| 410 |
|
| 411 |
-
with gr.
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
| 424 |
|
| 425 |
with gr.Row():
|
| 426 |
w1_textbox = gr.Textbox(
|
|
@@ -444,8 +449,11 @@ class Regularization:
|
|
| 444 |
label="Resolution (#points)",
|
| 445 |
)
|
| 446 |
|
|
|
|
|
|
|
| 447 |
with gr.Row():
|
| 448 |
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
|
|
|
|
| 449 |
with gr.Tab("Data"):
|
| 450 |
dataset_view = DatasetView()
|
| 451 |
dataset_view.build(state=dataset)
|
|
@@ -496,6 +504,10 @@ class Regularization:
|
|
| 496 |
fn=self.update_loss_type,
|
| 497 |
inputs=[loss_type_selection],
|
| 498 |
outputs=[loss_type],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
).then(
|
| 500 |
fn=self.compute_and_plot_loss_and_reg,
|
| 501 |
inputs=[
|
|
@@ -611,6 +623,33 @@ class Regularization:
|
|
| 611 |
outputs=self.loss_and_regularization_plot,
|
| 612 |
)
|
| 613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
resolution_slider.change(
|
| 615 |
self.update_resolution,
|
| 616 |
inputs=[resolution_slider],
|
|
|
|
| 263 |
l1_ratio = 1 if reg_type == "l1" else 0
|
| 264 |
alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
|
| 265 |
else:
|
| 266 |
+
return Image.new("RGB", (800, 800), color="white")
|
| 267 |
+
|
| 268 |
coefs = coefs.T
|
| 269 |
|
| 270 |
fig, ax = plt.subplots(figsize=(8, 8))
|
|
|
|
| 288 |
raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
|
| 289 |
return loss_type
|
| 290 |
|
| 291 |
+
def update_reg_path_visibility(self, loss_type: str):
|
| 292 |
+
return gr.update(visible=(loss_type == "l2"))
|
| 293 |
+
|
| 294 |
def update_regularizer(self, reg_type: str):
|
| 295 |
if reg_type not in self.REGULARIZER_TYPES:
|
| 296 |
raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
|
|
|
|
| 412 |
visible=True,
|
| 413 |
)
|
| 414 |
|
| 415 |
+
with gr.Group():
|
| 416 |
+
with gr.Row():
|
| 417 |
+
regularizer_type_selection = gr.Dropdown(
|
| 418 |
+
choices=['l1', 'l2'],
|
| 419 |
+
label='Regularizer type',
|
| 420 |
+
value='l2',
|
| 421 |
+
visible=True,
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
reg_textbox = gr.Textbox(
|
| 425 |
+
label="Regularizer levels",
|
| 426 |
+
value="10, 20, 30",
|
| 427 |
+
interactive=True,
|
| 428 |
+
)
|
| 429 |
|
| 430 |
with gr.Row():
|
| 431 |
w1_textbox = gr.Textbox(
|
|
|
|
| 449 |
label="Resolution (#points)",
|
| 450 |
)
|
| 451 |
|
| 452 |
+
submit_button = gr.Button("Submit changes")
|
| 453 |
+
|
| 454 |
with gr.Row():
|
| 455 |
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
|
| 456 |
+
|
| 457 |
with gr.Tab("Data"):
|
| 458 |
dataset_view = DatasetView()
|
| 459 |
dataset_view.build(state=dataset)
|
|
|
|
| 504 |
fn=self.update_loss_type,
|
| 505 |
inputs=[loss_type_selection],
|
| 506 |
outputs=[loss_type],
|
| 507 |
+
).then(
|
| 508 |
+
fn=self.update_reg_path_visibility,
|
| 509 |
+
inputs=[loss_type_selection],
|
| 510 |
+
outputs=[path_checkbox],
|
| 511 |
).then(
|
| 512 |
fn=self.compute_and_plot_loss_and_reg,
|
| 513 |
inputs=[
|
|
|
|
| 623 |
outputs=self.loss_and_regularization_plot,
|
| 624 |
)
|
| 625 |
|
| 626 |
+
submit_button.click(
|
| 627 |
+
self.update_w1_range,
|
| 628 |
+
inputs=[w1_textbox],
|
| 629 |
+
outputs=[w1_range],
|
| 630 |
+
).then(
|
| 631 |
+
self.update_w2_range,
|
| 632 |
+
inputs=[w2_textbox],
|
| 633 |
+
outputs=[w2_range],
|
| 634 |
+
).then(
|
| 635 |
+
self.update_reg_levels,
|
| 636 |
+
inputs=[reg_textbox],
|
| 637 |
+
outputs=[reg_levels],
|
| 638 |
+
).then(
|
| 639 |
+
fn=self.compute_and_plot_loss_and_reg,
|
| 640 |
+
inputs=[
|
| 641 |
+
dataset,
|
| 642 |
+
loss_type,
|
| 643 |
+
reg_type,
|
| 644 |
+
reg_levels,
|
| 645 |
+
w1_range,
|
| 646 |
+
w2_range,
|
| 647 |
+
num_dots,
|
| 648 |
+
plot_regularization_path,
|
| 649 |
+
],
|
| 650 |
+
outputs=self.loss_and_regularization_plot,
|
| 651 |
+
)
|
| 652 |
+
|
| 653 |
resolution_slider.change(
|
| 654 |
self.update_resolution,
|
| 655 |
inputs=[resolution_slider],
|