Joel Woodfield commited on
Commit
6034fcd
·
1 Parent(s): 9e75a75

Add data plot

Browse files
Files changed (1) hide show
  1. regularization.py +20 -15
regularization.py CHANGED
@@ -51,8 +51,6 @@ class Regularization:
51
  self.canvas_width = width
52
  self.canvas_height = height
53
 
54
- self.plot_regularization_path = False
55
-
56
  self.css ="""
57
  #my-button {
58
  height: 30px;
@@ -110,6 +108,8 @@ class Regularization:
110
  self.w2_range = (-100, 100)
111
  self.num_dots = 100
112
 
 
 
113
  def plot(self):
114
  '''
115
  '''
@@ -166,24 +166,23 @@ class Regularization:
166
  print(loss_levels)
167
 
168
  # plot contour plots
169
- fig = plt.figure(figsize=(5, 5))
170
- ax = plt.gca()
171
- ax.set_title("")
172
- ax.set_xlabel("w1")
173
- ax.set_ylabel("w2")
174
-
175
 
176
  cmap = plt.get_cmap("viridis")
177
  N = len(reg_levels)
178
  colors = [cmap(i / (N - 1)) for i in range(N)]
179
 
180
  # regularizer contours
181
- cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
182
- ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
183
 
184
  # loss contours
185
- cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
186
- ax.clabel(cs2, inline=True, fontsize=8)
187
 
188
  # plot path
189
  if self.plot_regularization_path:
@@ -198,7 +197,7 @@ class Regularization:
198
  path_w.append(stacked[mask][idx])
199
 
200
  path_w = np.array(path_w)
201
- ax.plot(path_w[:, 0], path_w[:, 1], "r-")
202
 
203
  # custom legend
204
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
@@ -207,7 +206,13 @@ class Regularization:
207
  if self.plot_regularization_path:
208
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
209
  handles.append(path_line)
210
- ax.legend(handles=handles)
 
 
 
 
 
 
211
 
212
  # plot solutions
213
  #for alpha, w, norm, mse in solutions:
@@ -215,7 +220,7 @@ class Regularization:
215
  ##ax.text(w[0], w[1], f"α={alpha}", fontsize=8)
216
 
217
  buf = io.BytesIO()
218
- ax.figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
219
  plt.close(fig)
220
  buf.seek(0)
221
  img = Image.open(buf)
 
51
  self.canvas_width = width
52
  self.canvas_height = height
53
 
 
 
54
  self.css ="""
55
  #my-button {
56
  height: 30px;
 
108
  self.w2_range = (-100, 100)
109
  self.num_dots = 100
110
 
111
+ self.plot_regularization_path = False
112
+
113
  def plot(self):
114
  '''
115
  '''
 
166
  print(loss_levels)
167
 
168
  # plot contour plots
169
+ # fig = plt.figure(figsize=(5, 5))
170
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
171
+ axs[0].set_title("")
172
+ axs[0].set_xlabel("w1")
173
+ axs[0].set_ylabel("w2")
 
174
 
175
  cmap = plt.get_cmap("viridis")
176
  N = len(reg_levels)
177
  colors = [cmap(i / (N - 1)) for i in range(N)]
178
 
179
  # regularizer contours
180
+ cs1 = axs[0].contour(W1, W2, regs, levels=reg_levels, colors=colors, linestyles="dashed")
181
+ axs[0].clabel(cs1, inline=True, fontsize=8) # show contour levels
182
 
183
  # loss contours
184
+ cs2 = axs[0].contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
185
+ axs[0].clabel(cs2, inline=True, fontsize=8)
186
 
187
  # plot path
188
  if self.plot_regularization_path:
 
197
  path_w.append(stacked[mask][idx])
198
 
199
  path_w = np.array(path_w)
200
+ axs[0].plot(path_w[:, 0], path_w[:, 1], "r-")
201
 
202
  # custom legend
203
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
 
206
  if self.plot_regularization_path:
207
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
208
  handles.append(path_line)
209
+ axs[0].legend(handles=handles)
210
+
211
+ # plot data points
212
+ axs[1].set_xlabel("X1")
213
+ axs[1].set_ylabel("X2")
214
+ sc = axs[1].scatter(X[:, 0], X[:, 1], c=y, cmap="viridis")
215
+ fig.colorbar(sc, ax=axs[1], label="y")
216
 
217
  # plot solutions
218
  #for alpha, w, norm, mse in solutions:
 
220
  ##ax.text(w[0], w[1], f"α={alpha}", fontsize=8)
221
 
222
  buf = io.BytesIO()
223
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
224
  plt.close(fig)
225
  buf.seek(0)
226
  img = Image.open(buf)