nanye commited on
Commit
56a0c53
·
1 Parent(s): eed30e6

initial commit

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. regularization.py +286 -0
  3. usage.md +1 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
+ app_file: regularization.py
9
  pinned: false
10
  ---
11
 
regularization.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from pathlib import Path
3
+ import pickle
4
+
5
+ import gradio as gr
6
+ import matplotlib.pyplot as plt
7
+ import inspect
8
+ import numpy as np
9
+ import pandas as pd
10
+ import io
11
+ from jinja2 import Template
12
+ from PIL import Image
13
+ import sklearn
14
+ from sklearn.linear_model import LogisticRegression
15
+ from sklearn.svm import LinearSVC
16
+ from sklearn.datasets import load_iris
17
+ from sklearn.metrics import classification_report, mean_squared_error, mean_absolute_error
18
+ from sklearn.datasets import make_regression
19
+ from sklearn.linear_model import Ridge
20
+ from sklearn.linear_model import Lasso
21
+
22
+ import traceback
23
+ import yaml
24
+
25
+ import logging
26
+ logging.basicConfig(
27
+ level=logging.INFO, # set minimum level to capture (DEBUG, INFO, WARNING, ERROR, CRITICAL)
28
+ format="%(asctime)s [%(levelname)s] %(message)s", # log format
29
+ )
30
+ logger = logging.getLogger("ELVIS")
31
+
32
+ def min_corresponding_entries(W1, W2, w1, tol=0.1):
33
+ #mask = np.isclose(W1, w1, atol=tol, rtol=0)
34
+ mask = (W1 <= w1)
35
+ #print(W1.max(), W1.min(), w1)
36
+
37
+ values = W2[mask]
38
+
39
+ if values.size == 0:
40
+ raise ValueError("No entries in W1 approximately equal to w1")
41
+
42
+ return np.min(values)
43
+
44
+ class Regularization:
45
+ def __init__(self, width, height):
46
+ # initialized in draw_plot
47
+ #self.canvas_width = -1
48
+ #self.canvas_height = -1
49
+
50
+ self.canvas_width = width
51
+ self.canvas_height = height
52
+
53
+ self.css ="""
54
+ #my-button {
55
+ height: 30px;
56
+ font-size: 16px;
57
+ }
58
+
59
+ #rowheight {
60
+ height: 90px;
61
+ }
62
+
63
+ .hidden-button {
64
+ display: none;
65
+ }
66
+
67
+ .report-table {
68
+ border: 0 !important;
69
+ }
70
+ .report-table tr, .report-table th, .report-table td, .report-table tbody, .report-table thead {
71
+ border: 0 !important;
72
+ padding: 6px 12px;
73
+ text-align: center;
74
+ }"""
75
+
76
+ # Regularization strengths
77
+ self.alphas = [0.01, 0.1, 1, 10, 100]
78
+
79
+ def l1_loss(y, pred):
80
+ return np.mean(abs(y - pred))
81
+
82
+ def l2_loss(y, pred):
83
+ return np.mean((y - pred)**2)
84
+
85
+ self.Losses = {#'l1': mean_absolute_error, # slow
86
+ #'l1': lambda y, pred: np.mean(abs(y - pred)),
87
+ 'l1': l1_loss,
88
+ #'l2': mean_squared_error, # slow
89
+ #'l2': lambda y, pred: np.mean((y - pred)**2)
90
+ 'l2': l2_loss
91
+ }
92
+ self.Regularizers = {'l1': lambda w: sum(abs(w)),
93
+ 'l2': np.linalg.norm
94
+ }
95
+
96
+ #self.Model = Ridge #l2 loss + l2 reg
97
+ #self.Model = Lasso #l2 loss + l1 reg
98
+
99
+ self.loss_type = 'l2'
100
+ self.reg_type = 'l2'
101
+
102
+ self.Loss = self.Losses[self.loss_type]
103
+ self.Regularizer = self.Regularizers[self.reg_type]
104
+
105
+ self.reg_levels = [10, 20, 30]
106
+
107
+ def plot(self):
108
+ '''
109
+ '''
110
+
111
+ logger.info("Initializing figure")
112
+ fig = plt.figure(figsize=(self.canvas_width/100., self.canvas_height/100.0), dpi=100)
113
+ # set entire figure to be the canvas to allow simple conversion of mouse
114
+ # position to coordinates in the figure
115
+ ax = fig.add_axes([0., 0., 1., 1.]) #
116
+ ax.margins(x=0, y=0) # no padding in both directions
117
+
118
+ # make a synthetic dataset with 2 features
119
+ X, y = make_regression(n_samples=200, n_features=2, noise=15, random_state=0)
120
+
121
+ # fit a regularized linear models and record parameters, regularizer value, and loss value
122
+ #solutions = []
123
+ #for alpha in self.alphas:
124
+ ## TODO: use PyTorch or cvx to implement a linear model class that supports
125
+ ## different types of losses and regularizers
126
+ #model = self.Model(alpha=alpha, fit_intercept=False) # no intercept
127
+ #model.fit(X, y)
128
+ #w = model.coef_
129
+ #loss = self.Loss(y, model.predict(X))
130
+ #solutions.append((alpha, w, self.Regularizer(w), loss))
131
+
132
+ # Extract contour levels from solutions
133
+ #reg_levels = [sol[2] for sol in solutions]
134
+ #reg_levels.reverse()
135
+ #loss_levels = [sol[3] for sol in solutions]
136
+
137
+ # build grid in parameter space
138
+ w1 = np.linspace(-100, 100, 400)
139
+ w2 = np.linspace(-100, 100, 400)
140
+ W1, W2 = np.meshgrid(w1, w2)
141
+
142
+ # compute regularizer surface
143
+ stacked = np.stack((W1, W2), axis=-1)
144
+ regs = np.apply_along_axis(self.Regularizer, -1, stacked)
145
+
146
+ logger.info("Computing losses " + str(self.Loss))
147
+ # compute loss surface
148
+ losses = np.zeros_like(W1)
149
+ for i in range(W1.shape[0]):
150
+ for j in range(W1.shape[1]):
151
+ w = np.array([W1[i, j], W2[i, j]])
152
+ y_pred = X @ w
153
+ losses[i, j] = self.Loss(y, y_pred)
154
+
155
+ logger.info("Computing loss levels")
156
+ reg_levels = self.reg_levels
157
+ loss_levels = [min_corresponding_entries(regs, losses, reg_level) for reg_level in reg_levels]
158
+ loss_levels.reverse()
159
+ print(reg_levels)
160
+ print(loss_levels)
161
+
162
+ # plot contour plots
163
+ fig = plt.figure(figsize=(5, 5))
164
+ ax = plt.gca()
165
+ ax.set_title("")
166
+ ax.set_xlabel("w1")
167
+ ax.set_ylabel("w2")
168
+
169
+
170
+ cmap = plt.get_cmap("viridis")
171
+ N = len(reg_levels)
172
+ colors = [cmap(i / (N - 1)) for i in range(N)]
173
+
174
+ # regularizer contours
175
+ cs1 = ax.contour(W1, W2, regs, levels=reg_levels, colors=colors)
176
+ #ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
177
+
178
+ # loss contours
179
+ cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
180
+ #ax.clabel(cs2, inline=True, fontsize=8)
181
+
182
+ # plot solutions
183
+ #for alpha, w, norm, mse in solutions:
184
+ #ax.plot(w[0], w[1], "ro")
185
+ ##ax.text(w[0], w[1], f"α={alpha}", fontsize=8)
186
+
187
+ buf = io.BytesIO()
188
+ ax.figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
189
+ plt.close(fig)
190
+ buf.seek(0)
191
+ img = Image.open(buf)
192
+
193
+ return img
194
+
195
+ def update_loss(self, loss_type):
196
+ self.loss_type = loss_type
197
+ self.Loss = self.Losses[loss_type]
198
+ return self.plot()
199
+
200
+ def update_regularizer(self, reg_type):
201
+ self.reg_type = reg_type
202
+ self.Regularizer = self.Regularizers[reg_type]
203
+
204
+ return self.plot()
205
+
206
+ def update_reg_levels(self, reg_levels):
207
+ self.reg_levels = [float(reg_level) for reg_level in reg_levels.split(",")]
208
+
209
+ return self.plot()
210
+
211
+ def update_resolution(self, num_dots):
212
+ self.num_dots = num_dots
213
+ return self.plot()
214
+
215
+ def launch(self):
216
+ # build the Gradio interface
217
+ with gr.Blocks(css=self.css) as demo:
218
+ # app title
219
+ gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
220
+
221
+ # GUI elements and layout
222
+ with gr.Row():
223
+ with gr.Column(scale=2):
224
+ self.data_image = gr.Image(value=self.plot(), container=True)
225
+
226
+ with gr.Column(scale=1):
227
+ with gr.Tab("Settings"):
228
+ dataset_radio = gr.Radio(["make_regression", "Upload"],
229
+ value="make_regression", label="Dataset type", elem_id="rowheight")
230
+
231
+ # upload data
232
+ file_chooser = gr.File(label="Choose a file", visible=False, elem_id="rowheight")
233
+ self.file_chooser = file_chooser
234
+
235
+ # loss type
236
+ loss_type = gr.Dropdown(choices=['l1', 'l2'],
237
+ label='Loss type',
238
+ value='l2',
239
+ visible=True)
240
+
241
+ # regularizer type
242
+ regularizer_type = gr.Dropdown(choices=['l1', 'l2', 'elastic-net'],
243
+ label='Regularizer type',
244
+ value='l2',
245
+ visible=True)
246
+
247
+ # regularization strength
248
+ #reg_textbox = gr.Textbox(label="Regularization constants")
249
+ reg_textbox = gr.Textbox(label="Regularizer levels",
250
+ value="10, 20, 30",
251
+ interactive=True)
252
+ self.reg_textbox = reg_textbox
253
+
254
+ with gr.Tab("Export"):
255
+ # use hidden download button to generate files on the fly
256
+ # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
257
+
258
+ btn_export_data = gr.Button("Data")
259
+ btn_export_data_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_data_hidden", elem_classes="hidden-button")
260
+
261
+ btn_export_model = gr.Button('Model')
262
+ btn_export_model_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_model_hidden", elem_classes="hidden-button")
263
+
264
+ btn_export_code = gr.Button('Code')
265
+ btn_export_code_hidden = gr.DownloadButton(label="You should not see this", elem_id="btn_export_code_hidden", elem_classes="hidden-button")
266
+
267
+ with gr.Tab("Options"):
268
+ slider = gr.Slider(minimum=100, maximum=1000, value=100, step=1, label="Resolution (#points)")
269
+
270
+ with gr.Tab("Usage"):
271
+ gr.Markdown(''.join(open('usage.md', 'r').readlines()))
272
+
273
+
274
+ # event handlers for GUI elements
275
+ loss_type.change(fn=self.update_loss, inputs=loss_type,
276
+ outputs=self.data_image)
277
+ regularizer_type.change(fn=self.update_regularizer,
278
+ inputs=regularizer_type, outputs=self.data_image)
279
+
280
+ reg_textbox.submit(self.update_reg_levels, inputs=reg_textbox,
281
+ outputs=self.data_image)
282
+
283
+ demo.launch()
284
+
285
+ visualizer = Regularization(width=1200, height=900)
286
+ visualizer.launch()
usage.md ADDED
@@ -0,0 +1 @@
 
 
1
+ **Quick start**