Spaces:
Sleeping
Sleeping
Commit ·
1f2626d
1
Parent(s): dfbaa58
Split the regularization and data plots into two tabs
Browse files- regularization.py +48 -33
regularization.py
CHANGED
|
@@ -134,7 +134,7 @@ class Regularization:
|
|
| 134 |
|
| 135 |
self.plot_regularization_path = False
|
| 136 |
|
| 137 |
-
def
|
| 138 |
'''
|
| 139 |
'''
|
| 140 |
|
|
@@ -192,23 +192,22 @@ class Regularization:
|
|
| 192 |
print(loss_levels)
|
| 193 |
|
| 194 |
# plot contour plots
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
axs[0].set_ylabel("w2")
|
| 200 |
|
| 201 |
cmap = plt.get_cmap("viridis")
|
| 202 |
N = len(reg_levels)
|
| 203 |
colors = [cmap(i / (N - 1)) for i in range(N)]
|
| 204 |
|
| 205 |
# regularizer contours
|
| 206 |
-
cs1 =
|
| 207 |
-
|
| 208 |
|
| 209 |
# loss contours
|
| 210 |
-
cs2 =
|
| 211 |
-
|
| 212 |
|
| 213 |
# plot path
|
| 214 |
if self.plot_regularization_path:
|
|
@@ -223,7 +222,7 @@ class Regularization:
|
|
| 223 |
path_w.append(stacked[mask][idx])
|
| 224 |
|
| 225 |
path_w = np.array(path_w)
|
| 226 |
-
|
| 227 |
|
| 228 |
# custom legend
|
| 229 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
|
@@ -232,13 +231,7 @@ class Regularization:
|
|
| 232 |
if self.plot_regularization_path:
|
| 233 |
path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
|
| 234 |
handles.append(path_line)
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
# plot data points
|
| 238 |
-
axs[1].set_xlabel("X1")
|
| 239 |
-
axs[1].set_ylabel("X2")
|
| 240 |
-
sc = axs[1].scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
|
| 241 |
-
fig.colorbar(sc, ax=axs[1], label="y")
|
| 242 |
|
| 243 |
# plot solutions
|
| 244 |
#for alpha, w, norm, mse in solutions:
|
|
@@ -253,42 +246,61 @@ class Regularization:
|
|
| 253 |
|
| 254 |
return img
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
def update_loss(self, loss_type):
|
| 257 |
self.loss_type = loss_type
|
| 258 |
self.loss = self.losses[loss_type]
|
| 259 |
-
return self.
|
| 260 |
|
| 261 |
def update_regularizer(self, reg_type):
|
| 262 |
self.reg_type = reg_type
|
| 263 |
self.regularizer = self.regularizers[reg_type]
|
| 264 |
|
| 265 |
-
return self.
|
| 266 |
|
| 267 |
def update_reg_levels(self, reg_levels):
|
| 268 |
self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
|
| 269 |
|
| 270 |
-
return self.
|
| 271 |
|
| 272 |
def update_w1_range(self, w1_range):
|
| 273 |
self.w1_range = [float(w1) for w1 in w1_range.split(",")]
|
| 274 |
logger.info("Updated w1 range to " + str(self.w1_range))
|
| 275 |
|
| 276 |
-
return self.
|
| 277 |
|
| 278 |
def update_w2_range(self, w2_range):
|
| 279 |
self.w2_range = [float(w2) for w2 in w2_range.split(",")]
|
| 280 |
logger.info("Updated w2 range to " + str(self.w2_range))
|
| 281 |
|
| 282 |
-
return self.
|
| 283 |
|
| 284 |
def update_resolution(self, num_dots):
|
| 285 |
self.num_dots = num_dots
|
| 286 |
logger.info("updated resolution to " + str(num_dots))
|
| 287 |
-
return self.
|
| 288 |
|
| 289 |
def update_plot_path(self, plot_path):
|
| 290 |
self.plot_regularization_path = plot_path
|
| 291 |
-
return self.
|
| 292 |
|
| 293 |
def launch(self):
|
| 294 |
# build the Gradio interface
|
|
@@ -299,7 +311,10 @@ class Regularization:
|
|
| 299 |
# GUI elements and layout
|
| 300 |
with gr.Row():
|
| 301 |
with gr.Column(scale=2):
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
with gr.Column(scale=1):
|
| 305 |
with gr.Tab("Settings"):
|
|
@@ -366,23 +381,23 @@ class Regularization:
|
|
| 366 |
|
| 367 |
# event handlers for GUI elements
|
| 368 |
loss_type.change(fn=self.update_loss, inputs=loss_type,
|
| 369 |
-
outputs=self.
|
| 370 |
regularizer_type.change(fn=self.update_regularizer,
|
| 371 |
-
inputs=regularizer_type, outputs=self.
|
| 372 |
|
| 373 |
reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
|
| 374 |
-
outputs=self.
|
| 375 |
|
| 376 |
w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
|
| 377 |
-
outputs=self.
|
| 378 |
|
| 379 |
w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
|
| 380 |
-
outputs=self.
|
| 381 |
|
| 382 |
-
slider.change(self.update_resolution, inputs=slider, outputs=self.
|
| 383 |
|
| 384 |
path_checkbox.change(
|
| 385 |
-
self.update_plot_path, inputs=path_checkbox, outputs=self.
|
| 386 |
)
|
| 387 |
|
| 388 |
demo.launch()
|
|
|
|
| 134 |
|
| 135 |
self.plot_regularization_path = False
|
| 136 |
|
| 137 |
+
def plot_regularization_contour(self):
|
| 138 |
'''
|
| 139 |
'''
|
| 140 |
|
|
|
|
| 192 |
print(loss_levels)
|
| 193 |
|
| 194 |
# plot contour plots
|
| 195 |
+
fig, ax = plt.subplots(figsize=(8, 8))
|
| 196 |
+
ax.set_title("")
|
| 197 |
+
ax.set_xlabel("w1")
|
| 198 |
+
ax.set_ylabel("w2")
|
|
|
|
| 199 |
|
| 200 |
cmap = plt.get_cmap("viridis")
|
| 201 |
N = len(reg_levels)
|
| 202 |
colors = [cmap(i / (N - 1)) for i in range(N)]
|
| 203 |
|
| 204 |
# regularizer contours
|
| 205 |
+
cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
|
| 206 |
+
ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
|
| 207 |
|
| 208 |
# loss contours
|
| 209 |
+
cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 210 |
+
ax.clabel(cs2, inline=True, fontsize=8)
|
| 211 |
|
| 212 |
# plot path
|
| 213 |
if self.plot_regularization_path:
|
|
|
|
| 222 |
path_w.append(stacked[mask][idx])
|
| 223 |
|
| 224 |
path_w = np.array(path_w)
|
| 225 |
+
ax.plot(path_w[:, 0], path_w[:, 1], "r-")
|
| 226 |
|
| 227 |
# custom legend
|
| 228 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
|
|
|
| 231 |
if self.plot_regularization_path:
|
| 232 |
path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
|
| 233 |
handles.append(path_line)
|
| 234 |
+
ax.legend(handles=handles)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
# plot solutions
|
| 237 |
#for alpha, w, norm, mse in solutions:
|
|
|
|
| 246 |
|
| 247 |
return img
|
| 248 |
|
| 249 |
+
def plot_data(self):
|
| 250 |
+
# make sure the data is the same as the one used in plot_regularization_contour
|
| 251 |
+
X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
|
| 252 |
+
fig, ax = plt.subplots(figsize=(8, 8))
|
| 253 |
+
|
| 254 |
+
# plot data points
|
| 255 |
+
ax.set_xlabel("X1")
|
| 256 |
+
ax.set_ylabel("X2")
|
| 257 |
+
sc = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
|
| 258 |
+
fig.colorbar(sc, ax=ax, label="y")
|
| 259 |
+
|
| 260 |
+
buf = io.BytesIO()
|
| 261 |
+
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
| 262 |
+
plt.close(fig)
|
| 263 |
+
buf.seek(0)
|
| 264 |
+
img = Image.open(buf)
|
| 265 |
+
|
| 266 |
+
return img
|
| 267 |
+
|
| 268 |
def update_loss(self, loss_type):
|
| 269 |
self.loss_type = loss_type
|
| 270 |
self.loss = self.losses[loss_type]
|
| 271 |
+
return self.plot_regularization_contour()
|
| 272 |
|
| 273 |
def update_regularizer(self, reg_type):
|
| 274 |
self.reg_type = reg_type
|
| 275 |
self.regularizer = self.regularizers[reg_type]
|
| 276 |
|
| 277 |
+
return self.plot_regularization_contour()
|
| 278 |
|
| 279 |
def update_reg_levels(self, reg_levels):
|
| 280 |
self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
|
| 281 |
|
| 282 |
+
return self.plot_regularization_contour()
|
| 283 |
|
| 284 |
def update_w1_range(self, w1_range):
|
| 285 |
self.w1_range = [float(w1) for w1 in w1_range.split(",")]
|
| 286 |
logger.info("Updated w1 range to " + str(self.w1_range))
|
| 287 |
|
| 288 |
+
return self.plot_regularization_contour()
|
| 289 |
|
| 290 |
def update_w2_range(self, w2_range):
|
| 291 |
self.w2_range = [float(w2) for w2 in w2_range.split(",")]
|
| 292 |
logger.info("Updated w2 range to " + str(self.w2_range))
|
| 293 |
|
| 294 |
+
return self.plot_regularization_contour()
|
| 295 |
|
| 296 |
def update_resolution(self, num_dots):
|
| 297 |
self.num_dots = num_dots
|
| 298 |
logger.info("updated resolution to " + str(num_dots))
|
| 299 |
+
return self.plot_regularization_contour()
|
| 300 |
|
| 301 |
def update_plot_path(self, plot_path):
|
| 302 |
self.plot_regularization_path = plot_path
|
| 303 |
+
return self.plot_regularization_contour()
|
| 304 |
|
| 305 |
def launch(self):
|
| 306 |
# build the Gradio interface
|
|
|
|
| 311 |
# GUI elements and layout
|
| 312 |
with gr.Row():
|
| 313 |
with gr.Column(scale=2):
|
| 314 |
+
with gr.Tab("Regularization contour"):
|
| 315 |
+
self.regularization_contour = gr.Image(value=self.plot_regularization_contour(), container=True)
|
| 316 |
+
with gr.Tab("Data"):
|
| 317 |
+
self.data_image = gr.Image(value=self.plot_data(), container=True)
|
| 318 |
|
| 319 |
with gr.Column(scale=1):
|
| 320 |
with gr.Tab("Settings"):
|
|
|
|
| 381 |
|
| 382 |
# event handlers for GUI elements
|
| 383 |
loss_type.change(fn=self.update_loss, inputs=loss_type,
|
| 384 |
+
outputs=self.regularization_contour)
|
| 385 |
regularizer_type.change(fn=self.update_regularizer,
|
| 386 |
+
inputs=regularizer_type, outputs=self.regularization_contour)
|
| 387 |
|
| 388 |
reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
|
| 389 |
+
outputs=self.regularization_contour)
|
| 390 |
|
| 391 |
w1_textbox.submit(self.update_w1_range, inputs=w1_textbox,
|
| 392 |
+
outputs=self.regularization_contour)
|
| 393 |
|
| 394 |
w2_textbox.submit(self.update_w2_range, inputs=w2_textbox,
|
| 395 |
+
outputs=self.regularization_contour)
|
| 396 |
|
| 397 |
+
slider.change(self.update_resolution, inputs=slider, outputs=self.regularization_contour)
|
| 398 |
|
| 399 |
path_checkbox.change(
|
| 400 |
+
self.update_plot_path, inputs=path_checkbox, outputs=self.regularization_contour
|
| 401 |
)
|
| 402 |
|
| 403 |
demo.launch()
|