Spaces:
Sleeping
Sleeping
- update layout
Browse files- allow adjusting parameter ranges and grid resolution
- regularization.py +56 -19
regularization.py
CHANGED
|
@@ -37,7 +37,7 @@ def min_corresponding_entries(W1, W2, w1, tol=0.1):
|
|
| 37 |
values = W2[mask]
|
| 38 |
|
| 39 |
if values.size == 0:
|
| 40 |
-
raise ValueError("No entries in W1
|
| 41 |
|
| 42 |
return np.min(values)
|
| 43 |
|
|
@@ -103,6 +103,9 @@ class Regularization:
|
|
| 103 |
self.Regularizer = self.Regularizers[self.reg_type]
|
| 104 |
|
| 105 |
self.reg_levels = [10, 20, 30]
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
def plot(self):
|
| 108 |
'''
|
|
@@ -135,8 +138,8 @@ class Regularization:
|
|
| 135 |
#loss_levels = [sol[3] for sol in solutions]
|
| 136 |
|
| 137 |
# build grid in parameter space
|
| 138 |
-
w1 = np.linspace(
|
| 139 |
-
w2 = np.linspace(
|
| 140 |
W1, W2 = np.meshgrid(w1, w2)
|
| 141 |
|
| 142 |
# compute regularizer surface
|
|
@@ -173,11 +176,11 @@ class Regularization:
|
|
| 173 |
|
| 174 |
# regularizer contours
|
| 175 |
cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
|
| 176 |
-
|
| 177 |
|
| 178 |
# loss contours
|
| 179 |
cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 180 |
-
|
| 181 |
|
| 182 |
# plot solutions
|
| 183 |
#for alpha, w, norm, mse in solutions:
|
|
@@ -208,8 +211,22 @@ class Regularization:
|
|
| 208 |
|
| 209 |
return self.plot()
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
def update_resolution(self, num_dots):
|
| 212 |
self.num_dots = num_dots
|
|
|
|
| 213 |
return self.plot()
|
| 214 |
|
| 215 |
def launch(self):
|
|
@@ -238,18 +255,32 @@ class Regularization:
|
|
| 238 |
value='l2',
|
| 239 |
visible=True)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
with gr.Tab("Export"):
|
| 255 |
# use hidden download button to generate files on the fly
|
|
@@ -264,8 +295,6 @@ class Regularization:
|
|
| 264 |
btn_export_code = gr.Button('Code')
|
| 265 |
btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
|
| 266 |
|
| 267 |
-
with gr.Tab("Options"):
|
| 268 |
-
slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
|
| 269 |
|
| 270 |
with gr.Tab("Usage"):
|
| 271 |
gr.Markdown(''.join(open('usage.md', 'r').readlines()))
|
|
@@ -280,6 +309,14 @@ class Regularization:
|
|
| 280 |
reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
|
| 281 |
outputs=self.data_image)
|
| 282 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
demo.launch()
|
| 284 |
|
| 285 |
visualizer = Regularization(width=1200, height=900)
|
|
|
|
| 37 |
values = W2[mask]
|
| 38 |
|
| 39 |
if values.size == 0:
|
| 40 |
+
raise ValueError("No entries in W1 less than equal to w1")
|
| 41 |
|
| 42 |
return np.min(values)
|
| 43 |
|
|
|
|
| 103 |
self.Regularizer = self.Regularizers[self.reg_type]
|
| 104 |
|
| 105 |
self.reg_levels = [10, 20, 30]
|
| 106 |
+
self.w1_range = (-100, 100)
|
| 107 |
+
self.w2_range = (-100, 100)
|
| 108 |
+
self.num_dots = 100
|
| 109 |
|
| 110 |
def plot(self):
|
| 111 |
'''
|
|
|
|
| 138 |
#loss_levels = [sol[3] for sol in solutions]
|
| 139 |
|
| 140 |
# build grid in parameter space
|
| 141 |
+
w1 = np.linspace(self.w1_range[0], self.w1_range[1], self.num_dots)
|
| 142 |
+
w2 = np.linspace(self.w2_range[0], self.w2_range[1], self.num_dots)
|
| 143 |
W1, W2 = np.meshgrid(w1, w2)
|
| 144 |
|
| 145 |
# compute regularizer surface
|
|
|
|
| 176 |
|
| 177 |
# regularizer contours
|
| 178 |
cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
|
| 179 |
+
ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
|
| 180 |
|
| 181 |
# loss contours
|
| 182 |
cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 183 |
+
ax.clabel(cs2, inline=True, fontsize=8)
|
| 184 |
|
| 185 |
# plot solutions
|
| 186 |
#for alpha, w, norm, mse in solutions:
|
|
|
|
| 211 |
|
| 212 |
return self.plot()
|
| 213 |
|
| 214 |
+
def update_w1_range(self, w1_range):
|
| 215 |
+
self.w1_range = [float(w1) for w1 in w1_range.split(",")]
|
| 216 |
+
logger.info("Updated w1 range to " + str(self.w1_range))
|
| 217 |
+
|
| 218 |
+
return self.plot()
|
| 219 |
+
|
| 220 |
+
def update_w2_range(self, w2_range):
|
| 221 |
+
self.w2_range = [float(w2) for w2 in w2_range.split(",")]
|
| 222 |
+
logger.info("Updated w2 range to " + str(self.w2_range))
|
| 223 |
+
|
| 224 |
+
return self.plot()
|
| 225 |
+
|
| 226 |
+
|
| 227 |
def update_resolution(self, num_dots):
|
| 228 |
self.num_dots = num_dots
|
| 229 |
+
logger.info("updated resolution to " + str(num_dots))
|
| 230 |
return self.plot()
|
| 231 |
|
| 232 |
def launch(self):
|
|
|
|
| 255 |
value='l2',
|
| 256 |
visible=True)
|
| 257 |
|
| 258 |
+
with gr.Row():
|
| 259 |
+
# regularizer type
|
| 260 |
+
regularizer_type = gr.Dropdown(choices=['l1', 'l2', 'elastic-net'],
|
| 261 |
+
label='Regularizer type',
|
| 262 |
+
value='l2',
|
| 263 |
+
visible=True)
|
| 264 |
+
|
| 265 |
+
# regularization strength
|
| 266 |
+
#reg_textbox = gr.Textbox(label="Regularization constants")
|
| 267 |
+
reg_textbox = gr.Textbox(label="Regularizer levels",
|
| 268 |
+
value="10, 20, 30",
|
| 269 |
+
interactive=True)
|
| 270 |
+
self.reg_textbox = reg_textbox
|
| 271 |
+
|
| 272 |
+
with gr.Row():
|
| 273 |
+
# parameter value ranges
|
| 274 |
+
w1_textbox = gr.Textbox(label="w1 range",
|
| 275 |
+
value="-100, 100",
|
| 276 |
+
interactive=True)
|
| 277 |
+
|
| 278 |
+
w2_textbox = gr.Textbox(label="w2 range",
|
| 279 |
+
value="-100, 100",
|
| 280 |
+
interactive=True)
|
| 281 |
+
|
| 282 |
+
# resolution
|
| 283 |
+
slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
|
| 284 |
|
| 285 |
with gr.Tab("Export"):
|
| 286 |
# use hidden download button to generate files on the fly
|
|
|
|
| 295 |
btn_export_code = gr.Button('Code')
|
| 296 |
btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
|
| 297 |
|
|
|
|
|
|
|
| 298 |
|
| 299 |
with gr.Tab("Usage"):
|
| 300 |
gr.Markdown(''.join(open('usage.md', 'r').readlines()))
|
|
|
|
| 309 |
reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
|
| 310 |
outputs=self.data_image)
|
| 311 |
|
| 312 |
+
w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
|
| 313 |
+
outputs=self.data_image)
|
| 314 |
+
|
| 315 |
+
w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
|
| 316 |
+
outputs=self.data_image)
|
| 317 |
+
|
| 318 |
+
slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
|
| 319 |
+
|
| 320 |
demo.launch()
|
| 321 |
|
| 322 |
visualizer = Regularization(width=1200, height=900)
|