joel-woodfield commited on
Commit
1f2626d
·
1 Parent(s): dfbaa58

Split the regularization and data plots into two tabs

Browse files
Files changed (1) hide show
  1. regularization.py +48 -33
regularization.py CHANGED
@@ -134,7 +134,7 @@ class Regularization:
134
 
135
  self.plot_regularization_path = False
136
 
137
- def plot(self):
138
  '''
139
  '''
140
 
@@ -192,23 +192,22 @@ class Regularization:
192
  print(loss_levels)
193
 
194
  # plot contour plots
195
- # fig = plt.figure(figsize=(5, 5))
196
- fig, axs = plt.subplots(1, 2, figsize=(10, 5))
197
- axs[0].set_title("")
198
- axs[0].set_xlabel("w1")
199
- axs[0].set_ylabel("w2")
200
 
201
  cmap = plt.get_cmap("viridis")
202
  N = len(reg_levels)
203
  colors = [cmap(i / (N - 1)) for i in range(N)]
204
 
205
  # regularizer contours
206
- cs1 = axs[0].contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
207
- axs[0].clabel(cs1, inline=True, fontsize=8) # show contour levels
208
 
209
  # loss contours
210
- cs2 = axs[0].contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
211
- axs[0].clabel(cs2, inline=True, fontsize=8)
212
 
213
  # plot path
214
  if self.plot_regularization_path:
@@ -223,7 +222,7 @@ class Regularization:
223
  path_w.append(stacked[mask][idx])
224
 
225
  path_w = np.array(path_w)
226
- axs[0].plot(path_w[:, 0], path_w[:, 1], "r-")
227
 
228
  # custom legend
229
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
@@ -232,13 +231,7 @@ class Regularization:
232
  if self.plot_regularization_path:
233
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
234
  handles.append(path_line)
235
- axs[0].legend(handles=handles)
236
-
237
- # plot data points
238
- axs[1].set_xlabel("X1")
239
- axs[1].set_ylabel("X2")
240
- sc = axs[1].scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
241
- fig.colorbar(sc, ax=axs[1], label="y")
242
 
243
  # plot solutions
244
  #for alpha, w, norm, mse in solutions:
@@ -253,42 +246,61 @@ class Regularization:
253
 
254
  return img
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  def update_loss(self, loss_type):
257
  self.loss_type = loss_type
258
  self.loss = self.losses[loss_type]
259
- return self.plot()
260
 
261
  def update_regularizer(self, reg_type):
262
  self.reg_type = reg_type
263
  self.regularizer = self.regularizers[reg_type]
264
 
265
- return self.plot()
266
 
267
  def update_reg_levels(self, reg_levels):
268
  self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
269
 
270
- return self.plot()
271
 
272
  def update_w1_range(self, w1_range):
273
  self.w1_range = [float(w1) for w1 in w1_range.split(",")]
274
  logger.info("Updated w1 range to " + str(self.w1_range))
275
 
276
- return self.plot()
277
 
278
  def update_w2_range(self, w2_range):
279
  self.w2_range = [float(w2) for w2 in w2_range.split(",")]
280
  logger.info("Updated w2 range to " + str(self.w2_range))
281
 
282
- return self.plot()
283
 
284
  def update_resolution(self, num_dots):
285
  self.num_dots = num_dots
286
  logger.info("updated resolution to " + str(num_dots))
287
- return self.plot()
288
 
289
  def update_plot_path(self, plot_path):
290
  self.plot_regularization_path = plot_path
291
- return self.plot()
292
 
293
  def launch(self):
294
  # build the Gradio interface
@@ -299,7 +311,10 @@ class Regularization:
299
  # GUI elements and layout
300
  with gr.Row():
301
  with gr.Column(scale=2):
302
- self.data_image = gr.Image(value=self.plot(), container=True)
 
 
 
303
 
304
  with gr.Column(scale=1):
305
  with gr.Tab("Settings"):
@@ -366,23 +381,23 @@ class Regularization:
366
 
367
  # event handlers for GUI elements
368
  loss_type.change(fn=self.update_loss, inputs=loss_type,
369
- outputs=self.data_image)
370
  regularizer_type.change(fn=self.update_regularizer,
371
- inputs=regularizer_type, outputs=self.data_image)
372
 
373
  reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
374
- outputs=self.data_image)
375
 
376
  w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
377
- outputs=self.data_image)
378
 
379
  w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
380
- outputs=self.data_image)
381
 
382
- slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
383
 
384
  path_checkbox.change(
385
- self.update_plot_path, inputs=path_checkbox, outputs=self.data_image
386
  )
387
 
388
  demo.launch()
 
134
 
135
  self.plot_regularization_path = False
136
 
137
+ def plot_regularization_contour(self):
138
  '''
139
  '''
140
 
 
192
  print(loss_levels)
193
 
194
  # plot contour plots
195
+ fig, ax = plt.subplots(figsize=(8, 8))
196
+ ax.set_title("")
197
+ ax.set_xlabel("w1")
198
+ ax.set_ylabel("w2")
 
199
 
200
  cmap = plt.get_cmap("viridis")
201
  N = len(reg_levels)
202
  colors = [cmap(i / (N - 1)) for i in range(N)]
203
 
204
  # regularizer contours
205
+ cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
206
+ ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
207
 
208
  # loss contours
209
+ cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
210
+ ax.clabel(cs2, inline=True, fontsize=8)
211
 
212
  # plot path
213
  if self.plot_regularization_path:
 
222
  path_w.append(stacked[mask][idx])
223
 
224
  path_w = np.array(path_w)
225
+ ax.plot(path_w[:, 0], path_w[:, 1], "r-")
226
 
227
  # custom legend
228
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
 
231
  if self.plot_regularization_path:
232
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
233
  handles.append(path_line)
234
+ ax.legend(handles=handles)
 
 
 
 
 
 
235
 
236
  # plot solutions
237
  #for alpha, w, norm, mse in solutions:
 
246
 
247
  return img
248
 
249
+ def plot_data(self):
250
+ # make sure the data is the same as the one used in plot_regularization_contour
251
+ X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
252
+ fig, ax = plt.subplots(figsize=(8, 8))
253
+
254
+ # plot data points
255
+ ax.set_xlabel("X1")
256
+ ax.set_ylabel("X2")
257
+ sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
258
+ fig.colorbar(sc, ax=ax, label="y")
259
+
260
+ buf = io.BytesIO()
261
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
262
+ plt.close(fig)
263
+ buf.seek(0)
264
+ img = Image.open(buf)
265
+
266
+ return img
267
+
268
  def update_loss(self, loss_type):
269
  self.loss_type = loss_type
270
  self.loss = self.losses[loss_type]
271
+ return self.plot_regularization_contour()
272
 
273
  def update_regularizer(self, reg_type):
274
  self.reg_type = reg_type
275
  self.regularizer = self.regularizers[reg_type]
276
 
277
+ return self.plot_regularization_contour()
278
 
279
  def update_reg_levels(self, reg_levels):
280
  self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
281
 
282
+ return self.plot_regularization_contour()
283
 
284
  def update_w1_range(self, w1_range):
285
  self.w1_range = [float(w1) for w1 in w1_range.split(",")]
286
  logger.info("Updated w1 range to " + str(self.w1_range))
287
 
288
+ return self.plot_regularization_contour()
289
 
290
  def update_w2_range(self, w2_range):
291
  self.w2_range = [float(w2) for w2 in w2_range.split(",")]
292
  logger.info("Updated w2 range to " + str(self.w2_range))
293
 
294
+ return self.plot_regularization_contour()
295
 
296
  def update_resolution(self, num_dots):
297
  self.num_dots = num_dots
298
  logger.info("updated resolution to " + str(num_dots))
299
+ return self.plot_regularization_contour()
300
 
301
  def update_plot_path(self, plot_path):
302
  self.plot_regularization_path = plot_path
303
+ return self.plot_regularization_contour()
304
 
305
  def launch(self):
306
  # build the Gradio interface
 
311
  # GUI elements and layout
312
  with gr.Row():
313
  with gr.Column(scale=2):
314
+ with gr.Tab("Regularization contour"):
315
+ self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
316
+ with gr.Tab("Data"):
317
+ self.data_image = gr.Image(value=self.plot_data(), container=True)
318
 
319
  with gr.Column(scale=1):
320
  with gr.Tab("Settings"):
 
381
 
382
  # event handlers for GUI elements
383
  loss_type.change(fn=self.update_loss, inputs=loss_type,
384
+ outputs=self.regularization_contour)
385
  regularizer_type.change(fn=self.update_regularizer,
386
+ inputs=regularizer_type, outputs=self.regularization_contour)
387
 
388
  reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
389
+ outputs=self.regularization_contour)
390
 
391
  w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
392
+ outputs=self.regularization_contour)
393
 
394
  w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
395
+ outputs=self.regularization_contour)
396
 
397
+ slider.change(self.update_resolution, inputs=slider, outputs=self.regularization_contour)
398
 
399
  path_checkbox.change(
400
+ self.update_plot_path, inputs=path_checkbox, outputs=self.regularization_contour
401
  )
402
 
403
  demo.launch()