nanye commited on
Commit
19d9a90
·
1 Parent(s): 26bd69f

- update layout

Browse files

- allow adjusting parameter ranges and grid resolution

Files changed (1) hide show
  1. regularization.py +56 -19
regularization.py CHANGED
@@ -37,7 +37,7 @@ def min_corresponding_entries(W1, W2, w1, tol=0.1):
37
  values = W2[mask]
38
 
39
  if values.size == 0:
40
- raise ValueError("No entries in W1 approximately equal to w1")
41
 
42
  return np.min(values)
43
 
@@ -103,6 +103,9 @@ class Regularization:
103
  self.Regularizer = self.Regularizers[self.reg_type]
104
 
105
  self.reg_levels = [10, 20, 30]
 
 
 
106
 
107
  def plot(self):
108
  '''
@@ -135,8 +138,8 @@ class Regularization:
135
  #loss_levels = [sol[3] for sol in solutions]
136
 
137
  # build grid in parameter space
138
- w1 = np.linspace(-100, 100, 400)
139
- w2 = np.linspace(-100, 100, 400)
140
  W1, W2 = np.meshgrid(w1, w2)
141
 
142
  # compute regularizer surface
@@ -173,11 +176,11 @@ class Regularization:
173
 
174
  # regularizer contours
175
  cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
176
- #ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
177
 
178
  # loss contours
179
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
180
- #ax.clabel(cs2, inline=True, fontsize=8)
181
 
182
  # plot solutions
183
  #for alpha, w, norm, mse in solutions:
@@ -208,8 +211,22 @@ class Regularization:
208
 
209
  return self.plot()
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  def update_resolution(self, num_dots):
212
  self.num_dots = num_dots
 
213
  return self.plot()
214
 
215
  def launch(self):
@@ -238,18 +255,32 @@ class Regularization:
238
  value='l2',
239
  visible=True)
240
 
241
- # regularizer type
242
- regularizer_type = gr.Dropdown(choices=['l1', 'l2', 'elastic-net'],
243
- label='Regularizer type',
244
- value='l2',
245
- visible=True)
246
-
247
- # regularization strength
248
- #reg_textbox = gr.Textbox(label="Regularization constants")
249
- reg_textbox = gr.Textbox(label="Regularizer levels",
250
- value="10, 20, 30",
251
- interactive=True)
252
- self.reg_textbox = reg_textbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  with gr.Tab("Export"):
255
  # use hidden download button to generate files on the fly
@@ -264,8 +295,6 @@ class Regularization:
264
  btn_export_code = gr.Button('Code')
265
  btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
266
 
267
- with gr.Tab("Options"):
268
- slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
269
 
270
  with gr.Tab("Usage"):
271
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
@@ -280,6 +309,14 @@ class Regularization:
280
  reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
281
  outputs=self.data_image)
282
 
 
 
 
 
 
 
 
 
283
  demo.launch()
284
 
285
  visualizer = Regularization(width=1200, height=900)
 
37
  values = W2[mask]
38
 
39
  if values.size == 0:
40
+ raise ValueError("No entries in W1 less than equal to w1")
41
 
42
  return np.min(values)
43
 
 
103
  self.Regularizer = self.Regularizers[self.reg_type]
104
 
105
  self.reg_levels = [10, 20, 30]
106
+ self.w1_range = (-100, 100)
107
+ self.w2_range = (-100, 100)
108
+ self.num_dots = 100
109
 
110
  def plot(self):
111
  '''
 
138
  #loss_levels = [sol[3] for sol in solutions]
139
 
140
  # build grid in parameter space
141
+ w1 = np.linspace(self.w1_range[0], self.w1_range[1], self.num_dots)
142
+ w2 = np.linspace(self.w2_range[0], self.w2_range[1], self.num_dots)
143
  W1, W2 = np.meshgrid(w1, w2)
144
 
145
  # compute regularizer surface
 
176
 
177
  # regularizer contours
178
  cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
179
+ ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
180
 
181
  # loss contours
182
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
183
+ ax.clabel(cs2, inline=True, fontsize=8)
184
 
185
  # plot solutions
186
  #for alpha, w, norm, mse in solutions:
 
211
 
212
  return self.plot()
213
 
214
+ def update_w1_range(self, w1_range):
215
+ self.w1_range = [float(w1) for w1 in w1_range.split(",")]
216
+ logger.info("Updated w1 range to " + str(self.w1_range))
217
+
218
+ return self.plot()
219
+
220
+ def update_w2_range(self, w2_range):
221
+ self.w2_range = [float(w2) for w2 in w2_range.split(",")]
222
+ logger.info("Updated w2 range to " + str(self.w2_range))
223
+
224
+ return self.plot()
225
+
226
+
227
  def update_resolution(self, num_dots):
228
  self.num_dots = num_dots
229
+ logger.info("updated resolution to " + str(num_dots))
230
  return self.plot()
231
 
232
  def launch(self):
 
255
  value='l2',
256
  visible=True)
257
 
258
+ with gr.Row():
259
+ # regularizer type
260
+ regularizer_type = gr.Dropdown(choices=['l1', 'l2', 'elastic-net'],
261
+ label='Regularizer type',
262
+ value='l2',
263
+ visible=True)
264
+
265
+ # regularization strength
266
+ #reg_textbox = gr.Textbox(label="Regularization constants")
267
+ reg_textbox = gr.Textbox(label="Regularizer levels",
268
+ value="10, 20, 30",
269
+ interactive=True)
270
+ self.reg_textbox = reg_textbox
271
+
272
+ with gr.Row():
273
+ # parameter value ranges
274
+ w1_textbox = gr.Textbox(label="w1 range",
275
+ value="-100, 100",
276
+ interactive=True)
277
+
278
+ w2_textbox = gr.Textbox(label="w2 range",
279
+ value="-100, 100",
280
+ interactive=True)
281
+
282
+ # resolution
283
+ slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
284
 
285
  with gr.Tab("Export"):
286
  # use hidden download button to generate files on the fly
 
295
  btn_export_code = gr.Button('Code')
296
  btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
297
 
 
 
298
 
299
  with gr.Tab("Usage"):
300
  gr.Markdown(''.join(open('usage.md', 'r').readlines()))
 
309
  reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
310
  outputs=self.data_image)
311
 
312
+ w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
313
+ outputs=self.data_image)
314
+
315
+ w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
316
+ outputs=self.data_image)
317
+
318
+ slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
319
+
320
  demo.launch()
321
 
322
  visualizer = Regularization(width=1200, height=900)