Spaces:
Sleeping
Sleeping
Commit ·
9e75a75
1
Parent(s): ba3e474
Add regularization path
Browse files- regularization.py +33 -2
regularization.py
CHANGED
|
@@ -51,6 +51,8 @@ class Regularization:
|
|
| 51 |
self.canvas_width = width
|
| 52 |
self.canvas_height = height
|
| 53 |
|
|
|
|
|
|
|
| 54 |
self.css ="""
|
| 55 |
#my-button {
|
| 56 |
height: 30px;
|
|
@@ -183,10 +185,29 @@ class Regularization:
|
|
| 183 |
cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 184 |
ax.clabel(cs2, inline=True, fontsize=8)
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
# custom legend
|
| 187 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
| 188 |
reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
# plot solutions
|
| 192 |
#for alpha, w, norm, mse in solutions:
|
|
@@ -229,12 +250,15 @@ class Regularization:
|
|
| 229 |
|
| 230 |
return self.plot()
|
| 231 |
|
| 232 |
-
|
| 233 |
def update_resolution(self, num_dots):
|
| 234 |
self.num_dots = num_dots
|
| 235 |
logger.info("updated resolution to " + str(num_dots))
|
| 236 |
return self.plot()
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def launch(self):
|
| 239 |
# build the Gradio interface
|
| 240 |
with gr.Blocks(css=self.css) as demo:
|
|
@@ -288,6 +312,9 @@ class Regularization:
|
|
| 288 |
# resolution
|
| 289 |
slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
|
| 290 |
|
|
|
|
|
|
|
|
|
|
| 291 |
with gr.Tab("Export"):
|
| 292 |
# use hidden download button to generate files on the fly
|
| 293 |
# https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
|
|
@@ -323,6 +350,10 @@ class Regularization:
|
|
| 323 |
|
| 324 |
slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
|
| 325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
demo.launch()
|
| 327 |
|
| 328 |
visualizer = Regularization(width=1200, height=900)
|
|
|
|
| 51 |
self.canvas_width = width
|
| 52 |
self.canvas_height = height
|
| 53 |
|
| 54 |
+
self.plot_regularization_path = False
|
| 55 |
+
|
| 56 |
self.css ="""
|
| 57 |
#my-button {
|
| 58 |
height: 30px;
|
|
|
|
| 185 |
cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 186 |
ax.clabel(cs2, inline=True, fontsize=8)
|
| 187 |
|
| 188 |
+
# plot path
|
| 189 |
+
if self.plot_regularization_path:
|
| 190 |
+
min_loss_reg = regs.ravel()[np.argmin(losses)]
|
| 191 |
+
path_reg_levels = np.linspace(0, min_loss_reg, 20)
|
| 192 |
+
path_w = []
|
| 193 |
+
for reg_level in path_reg_levels:
|
| 194 |
+
mask = regs <= reg_level
|
| 195 |
+
if np.sum(mask) == 0:
|
| 196 |
+
continue
|
| 197 |
+
idx = np.argmin(losses[mask])
|
| 198 |
+
path_w.append(stacked[mask][idx])
|
| 199 |
+
|
| 200 |
+
path_w = np.array(path_w)
|
| 201 |
+
ax.plot(path_w[:, 0], path_w[:, 1], "r-")
|
| 202 |
+
|
| 203 |
# custom legend
|
| 204 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
| 205 |
reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
|
| 206 |
+
handles = [loss_line, reg_line]
|
| 207 |
+
if self.plot_regularization_path:
|
| 208 |
+
path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
|
| 209 |
+
handles.append(path_line)
|
| 210 |
+
ax.legend(handles=handles)
|
| 211 |
|
| 212 |
# plot solutions
|
| 213 |
#for alpha, w, norm, mse in solutions:
|
|
|
|
| 250 |
|
| 251 |
return self.plot()
|
| 252 |
|
|
|
|
| 253 |
def update_resolution(self, num_dots):
|
| 254 |
self.num_dots = num_dots
|
| 255 |
logger.info("updated resolution to " + str(num_dots))
|
| 256 |
return self.plot()
|
| 257 |
|
| 258 |
+
def update_plot_path(self, plot_path):
|
| 259 |
+
self.plot_regularization_path = plot_path
|
| 260 |
+
return self.plot()
|
| 261 |
+
|
| 262 |
def launch(self):
|
| 263 |
# build the Gradio interface
|
| 264 |
with gr.Blocks(css=self.css) as demo:
|
|
|
|
| 312 |
# resolution
|
| 313 |
slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
|
| 314 |
|
| 315 |
+
# plot path
|
| 316 |
+
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
|
| 317 |
+
|
| 318 |
with gr.Tab("Export"):
|
| 319 |
# use hidden download button to generate files on the fly
|
| 320 |
# https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
|
|
|
|
| 350 |
|
| 351 |
slider.change(self.update_resolution, inputs=slider, outputs=self.data_image)
|
| 352 |
|
| 353 |
+
path_checkbox.change(
|
| 354 |
+
self.update_plot_path, inputs=path_checkbox, outputs=self.data_image
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
demo.launch()
|
| 358 |
|
| 359 |
visualizer = Regularization(width=1200, height=900)
|