Spaces:
Sleeping
Sleeping
Commit ·
dfbaa58
1
Parent(s): 6034fcd
Optimised loss and regularization calculations
Browse files- regularization.py +58 -32
regularization.py
CHANGED
|
@@ -42,6 +42,29 @@ def min_corresponding_entries(W1, W2, w1, tol=0.1):
|
|
| 42 |
|
| 43 |
return np.min(values)
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
class Regularization:
|
| 46 |
def __init__(self, width, height):
|
| 47 |
# initialized in draw_plot
|
|
@@ -77,22 +100,23 @@ class Regularization:
|
|
| 77 |
# Regularization strengths
|
| 78 |
self.alphas = [0.01, 0.1, 1, 10, 100]
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
|
| 97 |
#self.Model = Ridge #l2 loss + l2 reg
|
| 98 |
#self.Model = Lasso #l2 loss + l1 reg
|
|
@@ -100,13 +124,13 @@ class Regularization:
|
|
| 100 |
self.loss_type = 'l2'
|
| 101 |
self.reg_type = 'l2'
|
| 102 |
|
| 103 |
-
self.
|
| 104 |
-
self.
|
| 105 |
|
| 106 |
self.reg_levels = [10, 20, 30]
|
| 107 |
self.w1_range = (-100, 100)
|
| 108 |
self.w2_range = (-100, 100)
|
| 109 |
-
self.num_dots =
|
| 110 |
|
| 111 |
self.plot_regularization_path = False
|
| 112 |
|
|
@@ -132,7 +156,7 @@ class Regularization:
|
|
| 132 |
#model = self.Model(alpha=alpha, fit_intercept=False) # no intercept
|
| 133 |
#model.fit(X, y)
|
| 134 |
#w = model.coef_
|
| 135 |
-
#loss = self.
|
| 136 |
#solutions.append((alpha, w, self.Regularizer(w), loss))
|
| 137 |
|
| 138 |
# Extract contour levels from solutions
|
|
@@ -147,16 +171,18 @@ class Regularization:
|
|
| 147 |
|
| 148 |
# compute regularizer surface
|
| 149 |
stacked = np.stack((W1, W2), axis=-1)
|
| 150 |
-
regs = np.apply_along_axis(self.
|
|
|
|
| 151 |
|
| 152 |
-
logger.info("Computing losses " + str(self.
|
| 153 |
# compute loss surface
|
| 154 |
-
losses =
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
| 160 |
|
| 161 |
logger.info("Computing loss levels")
|
| 162 |
reg_levels = self.reg_levels
|
|
@@ -224,17 +250,17 @@ class Regularization:
|
|
| 224 |
plt.close(fig)
|
| 225 |
buf.seek(0)
|
| 226 |
img = Image.open(buf)
|
| 227 |
-
|
| 228 |
return img
|
| 229 |
|
| 230 |
def update_loss(self, loss_type):
|
| 231 |
self.loss_type = loss_type
|
| 232 |
-
self.
|
| 233 |
return self.plot()
|
| 234 |
|
| 235 |
def update_regularizer(self, reg_type):
|
| 236 |
self.reg_type = reg_type
|
| 237 |
-
self.
|
| 238 |
|
| 239 |
return self.plot()
|
| 240 |
|
|
@@ -315,7 +341,7 @@ class Regularization:
|
|
| 315 |
interactive=True)
|
| 316 |
|
| 317 |
# resolution
|
| 318 |
-
slider = gr.Slider(minimum=100, maximum=1000, value=
|
| 319 |
|
| 320 |
# plot path
|
| 321 |
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
|
|
|
|
| 42 |
|
| 43 |
return np.min(values)
|
| 44 |
|
| 45 |
+
|
| 46 |
+
def l1_norm(W):
    """Return the L1 norm (sum of absolute values) along the last axis of *W*.

    Works elementwise over a grid of weight vectors, e.g. a stacked
    meshgrid of shape (num_dots, num_dots, 2).
    """
    return np.abs(W).sum(axis=-1)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def l2_norm(W):
    """Return the Euclidean (L2) norm along the last axis of *W*.

    Accepts a single weight vector or a stacked grid of weight vectors;
    the norm is taken over the trailing (feature) axis.
    """
    norms = np.linalg.norm(W, axis=-1)
    return norms
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def l1_loss(W, y, X):
    """Mean absolute error of linear predictions for a grid of weight vectors.

    Vectorized replacement for looping over every grid cell: all candidate
    weight vectors are flattened, predicted at once, and the per-vector MAE
    is reshaped back to the grid layout.

    Args:
        W: array of weight vectors, shape (..., n_features) — typically a
           stacked meshgrid of shape (num_dots, num_dots, 2).
        y: target values, shape (n_samples,).
        X: design matrix, shape (n_samples, n_features).

    Returns:
        Array of shape ``W.shape[:-1]`` holding the MAE of each weight vector.
    """
    n_features = W.shape[-1]
    # One row of predictions per candidate weight vector: (n_grid, n_samples).
    preds = W.reshape(-1, n_features) @ X.T
    errors = np.mean(np.abs(y.reshape(1, -1) - preds), axis=1)
    # Restore the grid layout; generalizes the original square-grid
    # reshape(num_dots, num_dots) to any leading shape.
    return errors.reshape(W.shape[:-1])
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def l2_loss(W, y, X):
    """Mean squared error of linear predictions for a grid of weight vectors.

    Vectorized replacement for looping over every grid cell: all candidate
    weight vectors are flattened, predicted at once, and the per-vector MSE
    is reshaped back to the grid layout.

    Args:
        W: array of weight vectors, shape (..., n_features) — typically a
           stacked meshgrid of shape (num_dots, num_dots, 2).
        y: target values, shape (n_samples,).
        X: design matrix, shape (n_samples, n_features).

    Returns:
        Array of shape ``W.shape[:-1]`` holding the MSE of each weight vector.
    """
    n_features = W.shape[-1]
    # One row of predictions per candidate weight vector: (n_grid, n_samples).
    preds = W.reshape(-1, n_features) @ X.T
    errors = np.mean((y.reshape(1, -1) - preds) ** 2, axis=1)
    # Restore the grid layout; generalizes the original square-grid
    # reshape(num_dots, num_dots) to any leading shape.
    return errors.reshape(W.shape[:-1])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
class Regularization:
|
| 69 |
def __init__(self, width, height):
|
| 70 |
# initialized in draw_plot
|
|
|
|
| 100 |
# Regularization strengths
|
| 101 |
self.alphas = [0.01, 0.1, 1, 10, 100]
|
| 102 |
|
| 103 |
+
self.losses = {
|
| 104 |
+
#'l1': mean_absolute_error, # slow
|
| 105 |
+
#'l1': lambda y, pred: np.mean(abs(y - pred)),
|
| 106 |
+
# 'l1': l1_loss,
|
| 107 |
+
#'l2': mean_squared_error, # slow
|
| 108 |
+
#'l2': lambda y, pred: np.mean((y - pred)**2)
|
| 109 |
+
# 'l2': l2_loss
|
| 110 |
+
"l1": l1_loss,
|
| 111 |
+
"l2": l2_loss,
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
self.regularizers = {
|
| 115 |
+
# 'l1': lambda w: sum(abs(w)),
|
| 116 |
+
# 'l2': np.linalg.norm
|
| 117 |
+
"l1": l1_norm,
|
| 118 |
+
"l2": l2_norm,
|
| 119 |
+
}
|
| 120 |
|
| 121 |
#self.Model = Ridge #l2 loss + l2 reg
|
| 122 |
#self.Model = Lasso #l2 loss + l1 reg
|
|
|
|
| 124 |
self.loss_type = 'l2'
|
| 125 |
self.reg_type = 'l2'
|
| 126 |
|
| 127 |
+
self.loss = self.losses[self.loss_type]
|
| 128 |
+
self.regularizer = self.regularizers[self.reg_type]
|
| 129 |
|
| 130 |
self.reg_levels = [10, 20, 30]
|
| 131 |
self.w1_range = (-100, 100)
|
| 132 |
self.w2_range = (-100, 100)
|
| 133 |
+
self.num_dots = 500
|
| 134 |
|
| 135 |
self.plot_regularization_path = False
|
| 136 |
|
|
|
|
| 156 |
#model = self.Model(alpha=alpha, fit_intercept=False) # no intercept
|
| 157 |
#model.fit(X, y)
|
| 158 |
#w = model.coef_
|
| 159 |
+
#loss = self.loss(y, model.predict(X))
|
| 160 |
#solutions.append((alpha, w, self.Regularizer(w), loss))
|
| 161 |
|
| 162 |
# Extract contour levels from solutions
|
|
|
|
| 171 |
|
| 172 |
# compute regularizer surface
|
| 173 |
stacked = np.stack((W1, W2), axis=-1)
|
| 174 |
+
# regs = np.apply_along_axis(self.regularizer, -1, stacked)
|
| 175 |
+
regs = self.regularizer(stacked)
|
| 176 |
|
| 177 |
+
logger.info("Computing losses " + str(self.loss))
|
| 178 |
# compute loss surface
|
| 179 |
+
losses = self.loss(stacked, y, X)
|
| 180 |
+
# losses = np.zeros_like(W1)
|
| 181 |
+
# for i in range(W1.shape[0]):
|
| 182 |
+
# for j in range(W1.shape[1]):
|
| 183 |
+
# w = np.array([W1[i, j], W2[i, j]])
|
| 184 |
+
# y_pred = X @ w
|
| 185 |
+
# losses[i, j] = self.loss(y, y_pred)
|
| 186 |
|
| 187 |
logger.info("Computing loss levels")
|
| 188 |
reg_levels = self.reg_levels
|
|
|
|
| 250 |
plt.close(fig)
|
| 251 |
buf.seek(0)
|
| 252 |
img = Image.open(buf)
|
| 253 |
+
|
| 254 |
return img
|
| 255 |
|
| 256 |
def update_loss(self, loss_type):
|
| 257 |
self.loss_type = loss_type
|
| 258 |
+
self.loss = self.losses[loss_type]
|
| 259 |
return self.plot()
|
| 260 |
|
| 261 |
def update_regularizer(self, reg_type):
|
| 262 |
self.reg_type = reg_type
|
| 263 |
+
self.regularizer = self.regularizers[reg_type]
|
| 264 |
|
| 265 |
return self.plot()
|
| 266 |
|
|
|
|
| 341 |
interactive=True)
|
| 342 |
|
| 343 |
# resolution
|
| 344 |
+
slider = gr.Slider(minimum=100, maximum=1000, value=500, step=1, label="Resolution (#points)")
|
| 345 |
|
| 346 |
# plot path
|
| 347 |
path_checkbox = gr.Checkbox(label="Show regularization path", value=False)
|