Spaces:
Sleeping
Sleeping
Joel Woodfield commited on
Commit ·
6034fcd
1
Parent(s): 9e75a75
Add data plot
Browse files- regularization.py +20 -15
regularization.py
CHANGED
|
@@ -51,8 +51,6 @@ class Regularization:
|
|
| 51 |
self.canvas_width = width
|
| 52 |
self.canvas_height = height
|
| 53 |
|
| 54 |
-
self.plot_regularization_path = False
|
| 55 |
-
|
| 56 |
self.css ="""
|
| 57 |
#my-button {
|
| 58 |
height: 30px;
|
|
@@ -110,6 +108,8 @@ class Regularization:
|
|
| 110 |
self.w2_range = (-100, 100)
|
| 111 |
self.num_dots = 100
|
| 112 |
|
|
|
|
|
|
|
| 113 |
def plot(self):
|
| 114 |
'''
|
| 115 |
'''
|
|
@@ -166,24 +166,23 @@ class Regularization:
|
|
| 166 |
print(loss_levels)
|
| 167 |
|
| 168 |
# plot contour plots
|
| 169 |
-
fig = plt.figure(figsize=(5, 5))
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
|
| 176 |
cmap = plt.get_cmap("viridis")
|
| 177 |
N = len(reg_levels)
|
| 178 |
colors = [cmap(i / (N - 1)) for i in range(N)]
|
| 179 |
|
| 180 |
# regularizer contours
|
| 181 |
-
cs1 =
|
| 182 |
-
|
| 183 |
|
| 184 |
# loss contours
|
| 185 |
-
cs2 =
|
| 186 |
-
|
| 187 |
|
| 188 |
# plot path
|
| 189 |
if self.plot_regularization_path:
|
|
@@ -198,7 +197,7 @@ class Regularization:
|
|
| 198 |
path_w.append(stacked[mask][idx])
|
| 199 |
|
| 200 |
path_w = np.array(path_w)
|
| 201 |
-
|
| 202 |
|
| 203 |
# custom legend
|
| 204 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
|
@@ -207,7 +206,13 @@ class Regularization:
|
|
| 207 |
if self.plot_regularization_path:
|
| 208 |
path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
|
| 209 |
handles.append(path_line)
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
# plot solutions
|
| 213 |
#for alpha, w, norm, mse in solutions:
|
|
@@ -215,7 +220,7 @@ class Regularization:
|
|
| 215 |
##ax.text(w[0], w[1], f"α={alpha}", fontsize=8)
|
| 216 |
|
| 217 |
buf = io.BytesIO()
|
| 218 |
-
|
| 219 |
plt.close(fig)
|
| 220 |
buf.seek(0)
|
| 221 |
img = Image.open(buf)
|
|
|
|
| 51 |
self.canvas_width = width
|
| 52 |
self.canvas_height = height
|
| 53 |
|
|
|
|
|
|
|
| 54 |
self.css ="""
|
| 55 |
#my-button {
|
| 56 |
height: 30px;
|
|
|
|
| 108 |
self.w2_range = (-100, 100)
|
| 109 |
self.num_dots = 100
|
| 110 |
|
| 111 |
+
self.plot_regularization_path = False
|
| 112 |
+
|
| 113 |
def plot(self):
|
| 114 |
'''
|
| 115 |
'''
|
|
|
|
| 166 |
print(loss_levels)
|
| 167 |
|
| 168 |
# plot contour plots
|
| 169 |
+
# fig = plt.figure(figsize=(5, 5))
|
| 170 |
+
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
|
| 171 |
+
axs[0].set_title("")
|
| 172 |
+
axs[0].set_xlabel("w1")
|
| 173 |
+
axs[0].set_ylabel("w2")
|
|
|
|
| 174 |
|
| 175 |
cmap = plt.get_cmap("viridis")
|
| 176 |
N = len(reg_levels)
|
| 177 |
colors = [cmap(i / (N - 1)) for i in range(N)]
|
| 178 |
|
| 179 |
# regularizer contours
|
| 180 |
+
cs1 = axs[0].contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
|
| 181 |
+
axs[0].clabel(cs1, inline=True, fontsize=8) # show contour levels
|
| 182 |
|
| 183 |
# loss contours
|
| 184 |
+
cs2 = axs[0].contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
|
| 185 |
+
axs[0].clabel(cs2, inline=True, fontsize=8)
|
| 186 |
|
| 187 |
# plot path
|
| 188 |
if self.plot_regularization_path:
|
|
|
|
| 197 |
path_w.append(stacked[mask][idx])
|
| 198 |
|
| 199 |
path_w = np.array(path_w)
|
| 200 |
+
axs[0].plot(path_w[:, 0], path_w[:, 1], "r-")
|
| 201 |
|
| 202 |
# custom legend
|
| 203 |
loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
|
|
|
|
| 206 |
if self.plot_regularization_path:
|
| 207 |
path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
|
| 208 |
handles.append(path_line)
|
| 209 |
+
axs[0].legend(handles=handles)
|
| 210 |
+
|
| 211 |
+
# plot data points
|
| 212 |
+
axs[1].set_xlabel("X1")
|
| 213 |
+
axs[1].set_ylabel("X2")
|
| 214 |
+
sc = axs[1].scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
|
| 215 |
+
fig.colorbar(sc, ax=axs[1], label="y")
|
| 216 |
|
| 217 |
# plot solutions
|
| 218 |
#for alpha, w, norm, mse in solutions:
|
|
|
|
| 220 |
##ax.text(w[0], w[1], f"α={alpha}", fontsize=8)
|
| 221 |
|
| 222 |
buf = io.BytesIO()
|
| 223 |
+
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
| 224 |
plt.close(fig)
|
| 225 |
buf.seek(0)
|
| 226 |
img = Image.open(buf)
|