joel-woodfield committed on
Commit
dfbaa58
·
1 Parent(s): 6034fcd

Optimised loss and regularization calculations

Browse files
Files changed (1) hide show
  1. regularization.py +58 -32
regularization.py CHANGED
@@ -42,6 +42,29 @@ def min_corresponding_entries(W1, W2, w1, tol=0.1):
42
 
43
  return np.min(values)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  class Regularization:
46
  def __init__(self, width, height):
47
  # initialized in draw_plot
@@ -77,22 +100,23 @@ class Regularization:
77
  # Regularization strengths
78
  self.alphas = [0.01, 0.1, 1, 10, 100]
79
 
80
- def l1_loss(y, pred):
81
- return np.mean(abs(y - pred))
82
-
83
- def l2_loss(y, pred):
84
- return np.mean((y - pred)**2)
85
-
86
- self.Losses = {#'l1': mean_absolute_error, # slow
87
- #'l1': lambda y, pred: np.mean(abs(y - pred)),
88
- 'l1': l1_loss,
89
- #'l2': mean_squared_error, # slow
90
- #'l2': lambda y, pred: np.mean((y - pred)**2)
91
- 'l2': l2_loss
92
- }
93
- self.Regularizers = {'l1': lambda w: sum(abs(w)),
94
- 'l2': np.linalg.norm
95
- }
 
96
 
97
  #self.Model = Ridge #l2 loss + l2 reg
98
  #self.Model = Lasso #l2 loss + l1 reg
@@ -100,13 +124,13 @@ class Regularization:
100
  self.loss_type = 'l2'
101
  self.reg_type = 'l2'
102
 
103
- self.Loss = self.Losses[self.loss_type]
104
- self.Regularizer = self.Regularizers[self.reg_type]
105
 
106
  self.reg_levels = [10, 20, 30]
107
  self.w1_range = (-100, 100)
108
  self.w2_range = (-100, 100)
109
- self.num_dots = 100
110
 
111
  self.plot_regularization_path = False
112
 
@@ -132,7 +156,7 @@ class Regularization:
132
  #model = self.Model(alpha=alpha, fit_intercept=False) # no intercept
133
  #model.fit(X, y)
134
  #w = model.coef_
135
- #loss = self.Loss(y, model.predict(X))
136
  #solutions.append((alpha, w, self.Regularizer(w), loss))
137
 
138
  # Extract contour levels from solutions
@@ -147,16 +171,18 @@ class Regularization:
147
 
148
  # compute regularizer surface
149
  stacked = np.stack((W1, W2), axis=-1)
150
- regs = np.apply_along_axis(self.Regularizer, -1, stacked)
 
151
 
152
- logger.info("Computing losses " + str(self.Loss))
153
  # compute loss surface
154
- losses = np.zeros_like(W1)
155
- for i in range(W1.shape[0]):
156
- for j in range(W1.shape[1]):
157
- w = np.array([W1[i, j], W2[i, j]])
158
- y_pred = X @ w
159
- losses[i, j] = self.Loss(y, y_pred)
 
160
 
161
  logger.info("Computing loss levels")
162
  reg_levels = self.reg_levels
@@ -224,17 +250,17 @@ class Regularization:
224
  plt.close(fig)
225
  buf.seek(0)
226
  img = Image.open(buf)
227
-
228
  return img
229
 
230
  def update_loss(self, loss_type):
231
  self.loss_type = loss_type
232
- self.Loss = self.Losses[loss_type]
233
  return self.plot()
234
 
235
  def update_regularizer(self, reg_type):
236
  self.reg_type = reg_type
237
- self.Regularizer = self.Regularizers[reg_type]
238
 
239
  return self.plot()
240
 
@@ -315,7 +341,7 @@ class Regularization:
315
  interactive=True)
316
 
317
  # resolution
318
- slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
319
 
320
  # plot path
321
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
 
42
 
43
  return np.min(values)
44
 
45
+
46
def l1_norm(W):
    """L1 (Manhattan) norm taken along the last axis of W.

    W is a stack of weight vectors; the result has shape W.shape[:-1].
    """
    absolute = np.abs(W)
    return absolute.sum(axis=-1)
48
+
49
+
50
def l2_norm(W):
    """Euclidean (L2) norm taken along the last axis of W.

    W is a stack of weight vectors; the result has shape W.shape[:-1].
    """
    norms = np.linalg.norm(W, axis=-1)
    return norms
52
+
53
+
54
def l1_loss(W, y, X):
    """Mean-absolute-error surface of the linear model X @ w over a grid of weights.

    Generalized from the original, which hard-coded 2 features
    (``reshape(-1, 2)``) and a square grid (``reshape(num_dots, num_dots)``);
    this version supports any feature count and any grid shape while
    returning identical results for the original square, 2-feature case.

    Parameters
    ----------
    W : ndarray, shape (..., d)
        Grid of candidate weight vectors; the last axis holds the d coefficients.
    y : ndarray, shape (m,)
        Target values.
    X : ndarray, shape (m, d)
        Design matrix.

    Returns
    -------
    ndarray, shape W.shape[:-1]
        MAE for each candidate weight vector, reshaped back onto the grid.
    """
    grid_shape = W.shape[:-1]
    # One matrix product evaluates every candidate weight vector at once.
    preds = W.reshape(-1, W.shape[-1]) @ X.T
    errors = np.abs(y.reshape(1, -1) - preds)
    return np.mean(errors, axis=1).reshape(grid_shape)
59
+
60
+
61
def l2_loss(W, y, X):
    """Mean-squared-error surface of the linear model X @ w over a grid of weights.

    Generalized from the original, which hard-coded 2 features
    (``reshape(-1, 2)``) and a square grid (``reshape(num_dots, num_dots)``);
    this version supports any feature count and any grid shape while
    returning identical results for the original square, 2-feature case.

    Parameters
    ----------
    W : ndarray, shape (..., d)
        Grid of candidate weight vectors; the last axis holds the d coefficients.
    y : ndarray, shape (m,)
        Target values.
    X : ndarray, shape (m, d)
        Design matrix.

    Returns
    -------
    ndarray, shape W.shape[:-1]
        MSE for each candidate weight vector, reshaped back onto the grid.
    """
    grid_shape = W.shape[:-1]
    # One matrix product evaluates every candidate weight vector at once.
    preds = W.reshape(-1, W.shape[-1]) @ X.T
    squared_errors = (y.reshape(1, -1) - preds) ** 2
    return np.mean(squared_errors, axis=1).reshape(grid_shape)
66
+
67
+
68
  class Regularization:
69
  def __init__(self, width, height):
70
  # initialized in draw_plot
 
100
  # Regularization strengths
101
  self.alphas = [0.01, 0.1, 1, 10, 100]
102
 
103
+ self.losses = {
104
+ #'l1': mean_absolute_error, # slow
105
+ #'l1': lambda y, pred: np.mean(abs(y - pred)),
106
+ # 'l1': l1_loss,
107
+ #'l2': mean_squared_error, # slow
108
+ #'l2': lambda y, pred: np.mean((y - pred)**2)
109
+ # 'l2': l2_loss
110
+ "l1": l1_loss,
111
+ "l2": l2_loss,
112
+ }
113
+
114
+ self.regularizers = {
115
+ # 'l1': lambda w: sum(abs(w)),
116
+ # 'l2': np.linalg.norm
117
+ "l1": l1_norm,
118
+ "l2": l2_norm,
119
+ }
120
 
121
  #self.Model = Ridge #l2 loss + l2 reg
122
  #self.Model = Lasso #l2 loss + l1 reg
 
124
  self.loss_type = 'l2'
125
  self.reg_type = 'l2'
126
 
127
+ self.loss = self.losses[self.loss_type]
128
+ self.regularizer = self.regularizers[self.reg_type]
129
 
130
  self.reg_levels = [10, 20, 30]
131
  self.w1_range = (-100, 100)
132
  self.w2_range = (-100, 100)
133
+ self.num_dots = 500
134
 
135
  self.plot_regularization_path = False
136
 
 
156
  #model = self.Model(alpha=alpha, fit_intercept=False) # no intercept
157
  #model.fit(X, y)
158
  #w = model.coef_
159
+ #loss = self.loss(y, model.predict(X))
160
  #solutions.append((alpha, w, self.Regularizer(w), loss))
161
 
162
  # Extract contour levels from solutions
 
171
 
172
  # compute regularizer surface
173
  stacked = np.stack((W1, W2), axis=-1)
174
+ # regs = np.apply_along_axis(self.regularizer, -1, stacked)
175
+ regs = self.regularizer(stacked)
176
 
177
+ logger.info("Computing losses " + str(self.loss))
178
  # compute loss surface
179
+ losses = self.loss(stacked, y, X)
180
+ # losses = np.zeros_like(W1)
181
+ # for i in range(W1.shape[0]):
182
+ # for j in range(W1.shape[1]):
183
+ # w = np.array([W1[i, j], W2[i, j]])
184
+ # y_pred = X @ w
185
+ # losses[i, j] = self.loss(y, y_pred)
186
 
187
  logger.info("Computing loss levels")
188
  reg_levels = self.reg_levels
 
250
  plt.close(fig)
251
  buf.seek(0)
252
  img = Image.open(buf)
253
+
254
  return img
255
 
256
  def update_loss(self, loss_type):
257
  self.loss_type = loss_type
258
+ self.loss = self.losses[loss_type]
259
  return self.plot()
260
 
261
  def update_regularizer(self, reg_type):
262
  self.reg_type = reg_type
263
+ self.regularizer = self.regularizers[reg_type]
264
 
265
  return self.plot()
266
 
 
341
  interactive=True)
342
 
343
  # resolution
344
+ slider = gr.Slider(minimum=100, maximum=1000, value=500, step=1, label="Resolution (#points)")
345
 
346
  # plot path
347
  path_checkbox = gr.Checkbox(label="Show regularization path", value=False)