joel-woodfield committed on
Commit
770d448
·
1 Parent(s): 43700b8

Refactor code to separate frontend and backend

Browse files
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
- app_file: regularization.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.46.0
8
+ app_file: frontend.py
9
  pinned: false
10
  ---
11
 
backend.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+ import cvxpy as cp
5
+ import numpy as np
6
+ from sympy import Expr, lambdify
7
+
8
+
9
+
10
@dataclass
class DataGenerationOptions:
    """Options controlling how synthetic (x1, x2) sample points are drawn."""
    # How x-points are chosen: "grid" = even lattice, "random" = uniform draws.
    method: Literal["grid", "random"]
    # Requested number of points (grid mode uses ~sqrt(n) points per axis).
    num_samples: int
    # Std-dev of additive Gaussian noise on y; 0 disables noise.
    noise: float = 0.
15
+
16
+
17
@dataclass
class Dataset:
    """A regression dataset with two input features and one target.

    Stored as plain Python lists so it serializes cleanly through UI state.
    All three lists have equal length.
    """
    x1: list[float]
    x2: list[float]
    y: list[float]
22
+
23
+
24
@dataclass
class PlotsData:
    """Everything the frontend needs to render the contour/strength plots."""
    # Weight-space meshgrid coordinates.
    W1: np.ndarray
    W2: np.ndarray
    # Data loss evaluated at each (W1, W2) grid point.
    loss_values: np.ndarray
    # Regularizer norm evaluated at each grid point.
    norms: np.ndarray
    # Loss contour levels corresponding to each regularization budget.
    loss_levels: list[float]
    # User-supplied regularization budgets (contour levels for `norms`).
    reg_levels: list[float]
    # Unregularized least-squares solution: a point (2,) or a line (M, 2).
    unreg_solution: np.ndarray
    # Regularization path: solution w(lambda) per strength, shape (K, 2).
    path: np.ndarray
34
+
35
+
36
def generate_dataset(
    function: Expr,
    x1_lim: tuple[int, int],
    x2_lim: tuple[int, int],
    generation_options: DataGenerationOptions,
) -> Dataset:
    """Sample (x1, x2) points and evaluate a symbolic function to build a Dataset.

    Args:
        function: sympy expression in free symbols x1 and x2.
        x1_lim / x2_lim: inclusive (min, max) sampling range per axis.
        generation_options: sampling method, sample count, and noise level.

    Returns:
        A Dataset of flattened samples (grid mode yields ~num_samples points).

    Raises:
        ValueError: if the generation method is not "grid" or "random".
    """
    f = lambdify(('x1', 'x2'), function, modules='numpy')

    if generation_options.method == 'grid':
        # ~sqrt(n) points per axis so the full grid has about num_samples points
        points_per_axis = int(np.sqrt(generation_options.num_samples))
        x1 = np.linspace(x1_lim[0], x1_lim[1], points_per_axis)
        x2 = np.linspace(x2_lim[0], x2_lim[1], points_per_axis)
        X1, X2 = np.meshgrid(x1, x2)
        X1_flat = X1.flatten()
        X2_flat = X2.flatten()
    elif generation_options.method == 'random':
        X1_flat = np.random.uniform(x1_lim[0], x1_lim[1], generation_options.num_samples)
        X2_flat = np.random.uniform(x2_lim[0], x2_lim[1], generation_options.num_samples)
    else:
        raise ValueError(f"Unknown generation method: {generation_options.method}")

    # lambdify returns a bare scalar for constant expressions (e.g. "3"); broadcast
    # so the noise addition and .tolist() below always see a 1-D array
    Y = np.broadcast_to(
        np.asarray(f(X1_flat, X2_flat), dtype=float), X1_flat.shape
    ).copy()

    if generation_options.noise > 0:
        Y += np.random.normal(0, generation_options.noise, size=Y.shape)

    return Dataset(x1=X1_flat.tolist(), x2=X2_flat.tolist(), y=Y.tolist())
62
+
63
+
64
def load_dataset_from_csv(
    file_path: str, header: bool, x1_col: int, x2_col: int, y_col: int
) -> Dataset:
    """Load a Dataset from a CSV file, dropping rows with missing values.

    Args:
        file_path: path to the CSV file.
        header: whether the first row is a header to skip.
        x1_col / x2_col / y_col: 0-based column indices of the three fields.

    Returns:
        A Dataset built from the selected columns.
    """
    # genfromtxt maps unparseable fields to NaN instead of raising
    data = np.genfromtxt(file_path, delimiter=',', skip_header=1 if header else 0)
    # a single-row CSV parses to a 1-D array; promote so column indexing works
    data = np.atleast_2d(data)
    data = data[~np.isnan(data).any(axis=1)]  # remove rows with NaN values

    x1 = data[:, x1_col].tolist()
    x2 = data[:, x2_col].tolist()
    y = data[:, y_col].tolist()
    return Dataset(x1=x1, x2=x2, y=y)
75
+
76
+
77
def build_parameter_grid(
    w1_lim: tuple[float, float],
    w2_lim: tuple[float, float],
    min_num_points: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Build a weight-space meshgrid that is guaranteed to contain the origin.

    Each axis gets `min_num_points` evenly spaced values; if 0 is not already
    among them it is inserted in sorted position (so an axis may end up one
    point longer than requested).
    """
    axes = []
    for lo, hi in (w1_lim, w2_lim):
        axis = np.linspace(lo, hi, min_num_points)
        # the origin must be a grid point — regularizers are anchored there
        if 0 not in axis:
            axis = np.insert(axis, np.searchsorted(axis, 0), 0)
        axes.append(axis)

    W1, W2 = np.meshgrid(axes[0], axes[1])
    return W1, W2
93
+
94
+
95
def compute_loss(
    dataset: Dataset,
    w1: np.ndarray,
    w2: np.ndarray,
    loss: Literal["l1", "l2"],
) -> np.ndarray:
    """Evaluate the data loss of the model y ≈ w1*x1 + w2*x2 at every grid point.

    Args:
        dataset: the observed samples.
        w1 / w2: meshgrid arrays of candidate weights (same shape).
        loss: "l2" for mean squared error, "l1" for mean absolute error.

    Returns:
        Array of losses with the same shape as the weight grid.

    Raises:
        ValueError: if `loss` is not "l1" or "l2".
    """
    x1 = np.array(dataset.x1)
    x2 = np.array(dataset.x2)
    y = np.array(dataset.y).reshape(1, -1)

    W = np.stack([w1.flatten(), w2.flatten()], axis=-1)  # (G, 2)
    X = np.stack([x1, x2], axis=0)  # (2, N)
    y_pred = W @ X  # (G, N)

    residuals = y - y_pred

    # reshape to the grid's own shape: build_parameter_grid can produce a
    # non-square grid when 0 is inserted on only one axis, so a square
    # reshape(grid_size, grid_size) would fail there
    if loss == 'l2':
        return np.mean(residuals ** 2, axis=1).reshape(w1.shape)
    elif loss == 'l1':
        return np.mean(np.abs(residuals), axis=1).reshape(w1.shape)
    # previously fell through and silently returned None
    raise ValueError(f"Unknown loss: {loss}")
116
+
117
+
118
def compute_norms(
    w1: np.ndarray,
    w2: np.ndarray,
    norm: Literal["l1", "l2"],
) -> np.ndarray:
    """Elementwise regularizer norm of the weight vector (w1, w2).

    Raises:
        ValueError: if `norm` is not "l1" or "l2".
    """
    if norm == "l2":
        return np.sqrt(w1 ** 2 + w2 ** 2)
    elif norm == "l1":
        return np.abs(w1) + np.abs(w2)
    # previously fell through and silently returned None
    raise ValueError(f"Unknown norm: {norm}")
127
+
128
+
129
def compute_loss_levels(
    loss_values: np.ndarray,
    norms: np.ndarray,
    reg_levels: list[float],
) -> list[float]:
    """For each regularization budget, find the best loss achievable within it.

    For every budget r, takes the minimum of `loss_values` over grid points
    whose norm is <= r. Returns the resulting levels sorted ascending with
    duplicates removed (contour plotting needs strictly increasing levels).

    Raises:
        ValueError: if some budget admits no grid point at all.
    """
    best_losses = set()
    for budget in reg_levels:
        feasible = loss_values[norms <= budget]
        if feasible.size == 0:
            raise ValueError(f"No satisfying loss level for reg_level {budget}")
        best_losses.add(np.min(feasible))

    return sorted(best_losses)
148
+
149
+
150
def compute_unregularized_solution(
    dataset: Dataset,
    w1_range: tuple[float, float],
    w2_range: tuple[float, float],
    num_dots: int = 100,
) -> np.ndarray:
    """Solve the unregularized least-squares problem min_w ||Xw - y||^2.

    Returns either a single solution point of shape (2,) when X^T X is
    invertible, or an (M, 2) array of points sampled along the solution line
    (clipped to the visible w1/w2 plotting ranges) when it is singular.
    """
    x1 = np.array(dataset.x1)
    x2 = np.array(dataset.x2)
    y = np.array(dataset.y)

    X = np.stack([x1, x2], axis=-1)  # (N, 2)

    try:
        # find point solution if exists (normal equations; unique solution
        # only when X^T X is non-singular)
        w_opt = np.linalg.solve(X.T @ X, X.T @ y)

    except np.linalg.LinAlgError:
        # the solutions are on a line: its direction is the null-space
        # eigenvector of X^T X (the one with the smallest eigenvalue)
        eig_vals, eig_vecs = np.linalg.eigh(X.T @ X)

        line_direction = eig_vecs[:, np.argmin(eig_vals)]
        # NOTE(review): assumes the solution line is not vertical
        # (line_direction[0] != 0) — a vertical line would divide by zero
        m = line_direction[1] / line_direction[0]

        # any particular least-squares solution fixes the intercept of the
        # line w2 = m * w1 + b
        candidate_w = np.linalg.lstsq(X, y, rcond=None)[0]
        b = candidate_w[1] - m * candidate_w[0]

        w1_opt = np.linspace(w1_range[0], w1_range[1], num_dots)
        w2_opt = m * w1_opt + b
        w_opt = np.stack((w1_opt, w2_opt), axis=-1)

        # keep only the segment that lies inside the visible w2 range
        mask = (w2_opt <= w2_range[1]) & (w2_opt >= w2_range[0])
        w_opt = w_opt[mask]

    return w_opt
184
+
185
+
186
def compute_regularization_path(
    dataset: Dataset,
    loss_type: Literal["l1", "l2"],
    regularizer_type: Literal["l1", "l2"],
) -> np.ndarray:
    """Trace the regularized solution w(lambda) over a log-spaced strength grid.

    Solves min_w loss(y, Xw) + lambda * reg(w) for each lambda in
    np.logspace(-4, 4, 100).

    Returns:
        (100, 2) array of solutions ordered by *ascending* lambda, aligned
        with np.logspace(-4, 4, 100). Rows are NaN where the solver failed.

    Raises:
        ValueError: on an unknown loss or regularizer type.
    """
    x1 = np.array(dataset.x1)
    x2 = np.array(dataset.x2)
    y = np.array(dataset.y)

    X = np.stack([x1, x2], axis=1)  # (N, 2)

    w = cp.Variable(2)
    lambd = cp.Parameter(nonneg=True)

    if loss_type == "l2":
        loss_expr = cp.sum_squares(y - X @ w)
    elif loss_type == "l1":
        loss_expr = cp.norm1(y - X @ w)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    if regularizer_type == "l2":
        reg_expr = cp.sum_squares(w)
    elif regularizer_type == "l1":
        reg_expr = cp.norm1(w)
    else:
        raise ValueError(f"Unknown regularizer type: {regularizer_type}")

    objective = cp.Minimize(loss_expr + lambd * reg_expr)
    problem = cp.Problem(objective)

    # todo - user defined reg levels
    reg_levels = np.logspace(-4, 4, 100)

    # solve with reg levels in descending order for using warm start
    w_solutions = []
    for reg_level in sorted(reg_levels, reverse=True):
        lambd.value = reg_level
        problem.solve(warm_start=True)

        if w.value is None:
            # solver failed at this strength; mark the point as missing
            w_solutions.append(np.array([np.nan, np.nan]))
        else:
            w_solutions.append(w.value.copy())

    # reverse so the returned path is ordered by ascending lambda — the
    # strength plot indexes it directly against np.logspace(-4, 4, 100);
    # previously the descending solve order was returned as-is, mirroring it
    return np.array(w_solutions[::-1])
232
+
233
+
234
def compute_plot_values(
    dataset: Dataset,
    loss_type: Literal["l1", "l2"],
    regularizer_type: Literal["l1", "l2"],
    reg_levels: list[float],
    w1_range: tuple[float, float],
    w2_range: tuple[float, float],
    resolution: int,
) -> PlotsData:
    """Run the full backend pipeline and bundle everything the plots need.

    Builds the weight grid, evaluates losses and regularizer norms on it,
    derives matching contour levels, and computes both the unregularized
    solution and the regularization path.
    """
    W1, W2 = build_parameter_grid(w1_range, w2_range, resolution)
    loss_values = compute_loss(dataset, W1, W2, loss_type)
    norms = compute_norms(W1, W2, regularizer_type)

    return PlotsData(
        W1=W1,
        W2=W2,
        loss_values=loss_values,
        norms=norms,
        loss_levels=compute_loss_levels(loss_values, norms, reg_levels),
        reg_levels=reg_levels,
        unreg_solution=compute_unregularized_solution(dataset, w1_range, w2_range),
        path=compute_regularization_path(dataset, loss_type, regularizer_type),
    )
frontend.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Literal
3
+
4
+ import gradio as gr
5
+ from matplotlib.figure import Figure
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.lines as mlines
8
+ import numpy as np
9
+ from sympy import sympify
10
+
11
+ from backend import (
12
+ compute_plot_values,
13
+ generate_dataset,
14
+ load_dataset_from_csv,
15
+ Dataset,
16
+ DataGenerationOptions,
17
+ PlotsData,
18
+ )
19
+
20
# Extra stylesheet handed to the Gradio app; hides widgets tagged with the
# "hidden-button" CSS class.
CSS = """
.hidden-button {
display: none;
}
"""
25
+
26
+
27
def get_dataset(
    dataset_type: str,
    function: str,
    x1_range_input: str,
    x2_range_input: str,
    x_selection_method: str,
    sigma: float,
    nsample: int,
    csv_file: str,
    has_header: bool,
    x1_col: int,
    x2_col: int,
    y_col: int,
) -> Dataset:
    """Build a Dataset from the UI inputs, either synthetically or from a CSV.

    Args:
        dataset_type: "Generate" (sample a symbolic function) or "CSV".
        function: expression in x1, x2 (Generate mode).
        x1_range_input / x2_range_input: "min, max" strings (Generate mode).
        x_selection_method: "Grid" or "Random" (Generate mode).
        sigma / nsample: noise level and sample count (Generate mode).
        csv_file / has_header / x1_col / x2_col / y_col: CSV mode inputs.

    Raises:
        ValueError: on an unparseable function, malformed ranges, or an
            unknown dataset type / selection method.
    """
    if dataset_type == "Generate":
        try:
            function = sympify(function)
        except Exception as e:
            raise ValueError(f"Invalid function: {e}") from e

        x1_range = tuple(float(x.strip()) for x in x1_range_input.split(","))
        x2_range = tuple(float(x.strip()) for x in x2_range_input.split(","))

        if (len(x1_range) != 2 or len(x2_range) != 2):
            raise ValueError("x1_range and x2_range must be tuples of length 2")

        x_selection_method = x_selection_method.lower()
        if x_selection_method not in ("grid", "random"):
            raise ValueError(f"Invalid x_selection_method: {x_selection_method}")

        dataset = generate_dataset(
            function,
            x1_range,
            x2_range,
            DataGenerationOptions(
                x_selection_method,
                # gr.Number delivers floats; the sampler needs an integer count
                int(nsample),
                sigma,
            )
        )

    elif dataset_type == "CSV":
        try:
            dataset = load_dataset_from_csv(
                csv_file,
                has_header,
                # column indices also arrive as floats from gr.Number;
                # numpy column indexing requires ints
                int(x1_col),
                int(x2_col),
                int(y_col),
            )
        except Exception as e:
            gr.Info(f"Error loading CSV: {e}")
            raise e

    else:
        raise ValueError(f"Invalid dataset_type: {dataset_type}")

    return dataset
85
+
86
+
87
def parse_plot_settings(
    dataset: Dataset,
    loss_type: str,
    regularizer_type: str,
    reg_levels_input: str,
    w1_range_input: str,
    w2_range_input: str,
    resolution: int,
) -> tuple[Dataset, Literal["l1", "l2"], Literal["l1", "l2"], list[float], tuple[float, float], tuple[float, float], int]:
    """Validate raw UI strings and assemble the argument tuple for compute_plot_values.

    The dataset and resolution pass through unchanged; comma-separated
    strings are parsed into floats.

    Raises:
        ValueError: on an unknown loss/regularizer type or malformed ranges.
    """
    def _parse_floats(text: str) -> tuple[float, ...]:
        # "a, b, c" -> (a, b, c)
        return tuple(float(piece.strip()) for piece in text.split(","))

    reg_levels = list(_parse_floats(reg_levels_input))
    w1_range = _parse_floats(w1_range_input)
    w2_range = _parse_floats(w2_range_input)

    if loss_type not in ("l1", "l2"):
        raise ValueError(f"Invalid loss_type: {loss_type}")
    if regularizer_type not in ("l1", "l2"):
        raise ValueError(f"Invalid regularizer_type: {regularizer_type}")

    if len(w1_range) != 2 or len(w2_range) != 2:
        raise ValueError("w1_range and w2_range must be tuples of length 2")

    return (
        dataset,
        loss_type,
        regularizer_type,
        reg_levels,
        w1_range,
        w2_range,
        resolution,
    )
117
+
118
+
119
def generate_contour_plot(
    W1: np.ndarray,
    W2: np.ndarray,
    losses: np.ndarray,
    norms: np.ndarray,
    loss_levels: list[float],
    reg_levels: list[float],
    unreg_solution: np.ndarray,
    path: np.ndarray,
) -> Figure:
    """Draw loss and regularizer contours in weight space.

    Overlays the unregularized solution (point or line) and the
    regularization path on top of both contour families.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title("")
    ax.set_xlabel("w1")
    ax.set_ylabel("w2")

    cmap = plt.get_cmap("viridis")
    N = len(reg_levels)
    # guard the divisor: a single regularization level would divide by zero
    colors = [cmap(i / max(N - 1, 1)) for i in range(N)]

    # regularizer contours
    cs1 = ax.contour(W1, W2, norms, levels=reg_levels, colors=colors, linestyles="dashed")
    ax.clabel(cs1, inline=True, fontsize=8)  # show contour levels

    # loss contours
    cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
    ax.clabel(cs2, inline=True, fontsize=8)

    # unregularized solution: a single point (2,) or a line of points (M, 2)
    if unreg_solution.ndim == 1:
        ax.plot(unreg_solution[0], unreg_solution[1], "bx", markersize=5, label="unregularized solution")
    else:
        ax.plot(unreg_solution[:, 0], unreg_solution[:, 1], "b-", label="unregularized solution")

    ax.plot(path[:, 0], path[:, 1], "r-", label="regularization path")

    # hand-built legend: line style (not color) distinguishes the families
    loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
    reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
    handles = [loss_line, reg_line]

    path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
    handles.append(path_line)

    if unreg_solution.ndim == 1:
        handles.append(
            mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
        )
    else:
        handles.append(
            mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
        )

    ax.legend(handles=handles)

    ax.grid(True)

    return fig
176
+
177
+
178
def generate_data_plot(dataset: Dataset) -> Figure:
    """Scatter the raw (x1, x2) samples, colored by their target value y."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")

    points = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap='viridis')
    fig.colorbar(points, ax=ax)
    ax.grid(True)

    return fig
188
+
189
+
190
def generate_strength_plot(
    path: np.ndarray,
    reg_levels: np.ndarray,
):
    """Plot each weight along the regularization path against the strength.

    NOTE(review): assumes path[i] is the solution at reg_levels[i] and that
    reg_levels is sorted ascending — confirm against the backend's ordering.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_xlabel("Regularization Strength")
    ax.set_ylabel("Weight")

    # one curve per weight component, on a log-scaled strength axis
    for column, (style, label) in enumerate((('r-', 'w1'), ('b-', 'w2'))):
        ax.plot(reg_levels, path[:, column], style, label=label)

    ax.set_xscale('log')
    ax.legend()
    ax.grid(True)

    return fig
206
+
207
+
208
def generate_all_plots(
    dataset: Dataset,
    plots_data: PlotsData,
) -> tuple[Figure, Figure, Figure]:
    """Regenerate the contour, data, and strength figures from current state.

    Returns:
        (contour_plot, data_plot, strength_plot) — three figures; the
        original annotation claimed only two.
    """
    contour_plot = generate_contour_plot(
        plots_data.W1,
        plots_data.W2,
        plots_data.loss_values,
        plots_data.norms,
        plots_data.loss_levels,
        plots_data.reg_levels,
        plots_data.unreg_solution,
        plots_data.path,
    )

    data_plot = generate_data_plot(
        dataset
    )

    # todo: the strength axis duplicates the backend's hard-coded logspace
    strength_plot = generate_strength_plot(
        plots_data.path,
        np.logspace(-4, 4, 100),
    )

    return contour_plot, data_plot, strength_plot
233
+
234
+
235
def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
    """Toggle widget visibility when the dataset source radio changes.

    Returns updates for, in order: function, x1_textbox, x2_textbox,
    x_selection_method, sigma, nsample, csv_file, has_header, x1_col,
    x2_col, y_col. The first six belong to "Generate" mode, the last
    five to "CSV" mode — exactly one group is visible at a time.
    """
    generating = dataset_type == "Generate"
    generate_widgets = [gr.update(visible=generating)] * 6
    csv_widgets = [gr.update(visible=not generating)] * 5
    return tuple(generate_widgets + csv_widgets)
264
+
265
+
266
def launch():
    """Build the Gradio UI, wire up its callbacks, and start the app."""
    # --- default values for every input widget and the initial state ---
    default_dataset_type = "Generate"

    default_function = "-50 * x1 + 30 * x2"
    default_x1_range = "-1, 1"
    default_x2_range = "-1, 1"
    default_x_selection_method = "Grid"
    default_sigma = 0.1
    default_num_points = 100

    default_csv_file = ""
    default_has_header = False
    default_x1_col = 0
    default_x2_col = 1
    default_y_col = 2

    default_loss_type = "l2"
    default_regularizer_type = "l2"
    default_reg_levels = "10, 20, 30"
    default_w1_range = "-100, 100"
    default_w2_range = "-100, 100"
    default_resolution = 100

    with gr.Blocks() as demo:
        gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")

        # session state: the current Dataset, pre-populated from the defaults
        dataset = gr.State(
            get_dataset(
                default_dataset_type,
                default_function,
                default_x1_range,
                default_x2_range,
                default_x_selection_method,
                default_sigma,
                default_num_points,
                default_csv_file,
                default_has_header,
                default_x1_col,
                default_x2_col,
                default_y_col,
            )
        )

        # wrapped PlotsData
        plots_data = gr.State(
            compute_plot_values(
                *parse_plot_settings(
                    dataset.value,
                    default_loss_type,
                    default_regularizer_type,
                    default_reg_levels,
                    default_w1_range,
                    default_w2_range,
                    default_resolution,
                )
            )
        )

        with gr.Row():
            # left column: the three plot tabs
            with gr.Column(scale=2):
                with gr.Tab("Contours"):
                    main_plot = gr.Plot(
                        value=generate_contour_plot(
                            plots_data.value.W1,
                            plots_data.value.W2,
                            plots_data.value.loss_values,
                            plots_data.value.norms,
                            plots_data.value.loss_levels,
                            plots_data.value.reg_levels,
                            plots_data.value.unreg_solution,
                            plots_data.value.path,
                        )
                    )
                with gr.Tab("Data"):
                    # todo
                    data_plot = gr.Plot(
                        value=generate_data_plot(
                            dataset.value
                        )
                    )
                with gr.Tab("Strength"):
                    # todo
                    strength_plot = gr.Plot(
                        value=generate_strength_plot(
                            plots_data.value.path,
                            np.logspace(-4, 4, 100),  # todo
                        )
                    )

            # right column: the control tabs
            with gr.Column(scale=1):
                with gr.Tab("Data"):
                    with gr.Row():
                        dataset_type = gr.Radio(
                            label="Dataset type",
                            choices=["Generate", "CSV"],
                            value=default_dataset_type,
                            interactive=True,
                        )

                    # --- "Generate" mode widgets (visible by default) ---
                    with gr.Row():
                        function = gr.Textbox(
                            label="Function (in terms of x1 and x2)",
                            value=default_function,
                            interactive=True,
                        )

                    with gr.Row():
                        x1_textbox = gr.Textbox(
                            label="x1 range",
                            value=default_x1_range,
                            interactive=True,
                        )
                        x2_textbox = gr.Textbox(
                            label="x2 range",
                            value=default_x2_range,
                            interactive=True,
                        )
                        x_selection_method = gr.Radio(
                            label="How to select x points",
                            choices=["Grid", "Random"],
                            value=default_x_selection_method,
                            interactive=True,
                        )

                    with gr.Row():
                        sigma = gr.Number(
                            label="Gaussian noise standard deviation",
                            value=default_sigma,
                            interactive=True,
                        )
                        nsample = gr.Number(
                            label="Number of points",
                            value=default_num_points,
                            interactive=True,
                        )

                    # --- "CSV" mode widgets (hidden until CSV is chosen) ---
                    with gr.Row():
                        csv_file = gr.File(
                            label="Upload CSV file - must have columns: (x1, x2, y)",
                            file_types=['.csv'],
                            visible=False,  # function mode is default
                        )

                    with gr.Row():
                        has_header = gr.Checkbox(
                            label="CSV has header row",
                            value=default_has_header,
                            visible=False,
                        )

                    with gr.Row():
                        x1_col = gr.Number(
                            label="x1 column index (0-based)",
                            value=default_x1_col,
                            visible=False,
                        )
                        x2_col = gr.Number(
                            label="x2 column index (0-based)",
                            value=default_x2_col,
                            visible=False,
                        )
                        y_col = gr.Number(
                            label="y column index (0-based)",
                            value=default_y_col,
                            visible=False,
                        )

                    # swap widget visibility between Generate and CSV modes
                    dataset_type.change(
                        fn=handle_dataset_type_change,
                        inputs=[dataset_type],
                        outputs=[
                            function,
                            x1_textbox,
                            x2_textbox,
                            x_selection_method,
                            sigma,
                            nsample,
                            csv_file,
                            has_header,
                            x1_col,
                            x2_col,
                            y_col,
                        ],
                    )

                with gr.Tab("Regularization"):
                    with gr.Row():
                        loss_type_dropdown = gr.Dropdown(
                            label="Loss type",
                            choices=["l1", "l2"],
                            value=default_loss_type,
                            interactive=True,
                        )

                    with gr.Row():
                        regularizer_type_dropdown = gr.Dropdown(
                            label="Regularizer type",
                            choices=["l1", "l2"],
                            value=default_regularizer_type,
                            interactive=True,
                        )
                        regularizer_levels_textbox = gr.Textbox(
                            label="Regularization levels (comma-separated)",
                            value=default_reg_levels,
                            interactive=True,
                        )

                    with gr.Row():
                        w1_range_textbox = gr.Textbox(
                            label="w1 range (min,max)",
                            value=default_w1_range,
                            interactive=True,
                        )
                        w2_range_textbox = gr.Textbox(
                            label="w2 range (min,max)",
                            value=default_w2_range,
                            interactive=True,
                        )

                    resolution_slider = gr.Slider(
                        label="Grid resolution",
                        value=default_resolution,
                        minimum=100,
                        maximum=400,
                        step=1,
                        interactive=True,
                    )

                    # pipeline: rebuild dataset -> recompute plot data -> redraw
                    gr.Button("Generate Plots").click(
                        fn=get_dataset,
                        inputs=[
                            dataset_type,
                            function,
                            x1_textbox,
                            x2_textbox,
                            x_selection_method,
                            sigma,
                            nsample,
                            csv_file,
                            has_header,
                            x1_col,
                            x2_col,
                            y_col,
                        ],
                        outputs=[dataset],
                    ).then(
                        fn=lambda *args: compute_plot_values(*parse_plot_settings(*args)),
                        inputs=[
                            dataset,
                            loss_type_dropdown,
                            regularizer_type_dropdown,
                            regularizer_levels_textbox,
                            w1_range_textbox,
                            w2_range_textbox,
                            resolution_slider,
                        ],
                        outputs=[plots_data],
                    ).then(
                        fn=generate_all_plots,
                        inputs=[dataset, plots_data],
                        outputs=[main_plot, data_plot, strength_plot],
                    )

                # placeholder tabs, not implemented yet
                with gr.Tab("Export"):
                    pass

                with gr.Tab("Usage"):
                    pass

    # NOTE(review): css is normally passed to gr.Blocks(css=...), not to
    # launch() — verify this against the installed gradio version
    demo.launch(css=CSS)
536
+
537
+
538
if __name__ == "__main__":
    # start the Gradio app when run as a script
    launch()
dataset.py → old/dataset.py RENAMED
File without changes
regularization.py → old/regularization.py RENAMED
File without changes
requirements.txt CHANGED
@@ -1,8 +1,9 @@
 
1
  matplotlib
 
2
  numpy
3
  pandas
4
- scikit-learn
5
- mpu
6
- numexpr
7
  pillow
8
  plotly
 
 
 
1
+ cvxpy
2
  matplotlib
3
+ mpu
4
  numpy
5
  pandas
 
 
 
6
  pillow
7
  plotly
8
+ scikit-learn
9
+ sympy