joel-woodfield commited on
Commit
9e75a75
·
1 Parent(s): ba3e474

Add regularization path

Browse files
Files changed (1) hide show
  1. regularization.py +33 -2
regularization.py CHANGED
@@ -51,6 +51,8 @@ class Regularization:
51
  self.canvas_width = width
52
  self.canvas_height = height
53
 
 
 
54
  self.css ="""
55
  #my-button {
56
  height: 30px;
@@ -183,10 +185,29 @@ class Regularization:
183
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
184
  ax.clabel(cs2, inline=True, fontsize=8)
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  # custom legend
187
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
188
  reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
189
- ax.legend(handles=[loss_line, reg_line])
 
 
 
 
190
 
191
  # plot solutions
192
  #for alpha, w, norm, mse in solutions:
@@ -229,12 +250,15 @@ class Regularization:
229
 
230
  return self.plot()
231
 
232
-
233
  def update_resolution(self, num_dots):
234
  self.num_dots = num_dots
235
  logger.info("updated resolution to " + str(num_dots))
236
  return self.plot()
237
 
 
 
 
 
238
  def launch(self):
239
  # build the Gradio interface
240
  with gr.Blocks(css=self.css) as demo:
@@ -288,6 +312,9 @@ class Regularization:
288
  # resolution
289
  slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
290
 
 
 
 
291
  with gr.Tab("Export"):
292
  # use hidden download button to generate files on the fly
293
  # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
@@ -323,6 +350,10 @@ class Regularization:
323
 
324
  slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
325
 
 
 
 
 
326
  demo.launch()
327
 
328
  visualizer = Regularization(width=1200, height=900)
 
51
  self.canvas_width = width
52
  self.canvas_height = height
53
 
54
+ self.plot_regularization_path = False
55
+
56
  self.css ="""
57
  #my-button {
58
  height: 30px;
 
185
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
186
  ax.clabel(cs2, inline=True, fontsize=8)
187
 
188
+ # plot path
189
+ if self.plot_regularization_path:
190
+ min_loss_reg = regs.ravel()[np.argmin(losses)]
191
+ path_reg_levels = np.linspace(0, min_loss_reg, 20)
192
+ path_w = []
193
+ for reg_level in path_reg_levels:
194
+ mask = regs <= reg_level
195
+ if np.sum(mask) == 0:
196
+ continue
197
+ idx = np.argmin(losses[mask])
198
+ path_w.append(stacked[mask][idx])
199
+
200
+ path_w = np.array(path_w)
201
+ ax.plot(path_w[:, 0], path_w[:, 1], "r-")
202
+
203
  # custom legend
204
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
205
  reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
206
+ handles = [loss_line, reg_line]
207
+ if self.plot_regularization_path:
208
+ path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
209
+ handles.append(path_line)
210
+ ax.legend(handles=handles)
211
 
212
  # plot solutions
213
  #for alpha, w, norm, mse in solutions:
 
250
 
251
  return self.plot()
252
 
 
253
  def update_resolution(self, num_dots):
254
  self.num_dots = num_dots
255
  logger.info("updated resolution to " + str(num_dots))
256
  return self.plot()
257
 
258
+ def update_plot_path(self, plot_path):
259
+ self.plot_regularization_path = plot_path
260
+ return self.plot()
261
+
262
  def launch(self):
263
  # build the Gradio interface
264
  with gr.Blocks(css=self.css) as demo:
 
312
  # resolution
313
  slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
314
 
315
+ # plot path
316
+ path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
317
+
318
  with gr.Tab("Export"):
319
  # use hidden download button to generate files on the fly
320
  # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
 
350
 
351
  slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
352
 
353
+ path_checkbox.change(
354
+ self.update_plot_path, inputs=path_checkbox, outputs=self.data_image
355
+ )
356
+
357
  demo.launch()
358
 
359
  visualizer = Regularization(width=1200, height=900)