Joel Woodfield commited on
Commit
484caec
·
1 Parent(s): d904d83

Refactor to use a backend manager

Browse files
backend/src/__pycache__/backend.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/backend.cpython-312.pyc and b/backend/src/__pycache__/backend.cpython-312.pyc differ
 
backend/src/__pycache__/logic.cpython-312.pyc ADDED
Binary file (12.7 kB). View file
 
backend/src/__pycache__/logic.cpython-314.pyc ADDED
Binary file (15.9 kB). View file
 
backend/src/__pycache__/manager.cpython-312.pyc ADDED
Binary file (11.3 kB). View file
 
backend/src/__pycache__/manager.cpython-314.pyc ADDED
Binary file (12.8 kB). View file
 
backend/src/{backend.py → logic.py} RENAMED
File without changes
backend/src/manager.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import matplotlib.lines as mlines
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from matplotlib.figure import Figure
7
+ from sympy import sympify
8
+
9
+ from logic import (
10
+ DataGenerationOptions,
11
+ Dataset,
12
+ PlotsData,
13
+ compute_plot_values,
14
+ generate_dataset,
15
+ load_dataset_from_csv,
16
+ )
17
+
18
+
19
+ class Manager:
20
+ def __init__(self, dataset: Dataset | None = None, plots_data: PlotsData | None = None):
21
+ self.dataset = dataset
22
+ self.plots_data = plots_data
23
+
24
+ def update_dataset(
25
+ self,
26
+ dataset_type: str,
27
+ function: str,
28
+ x1_range_input: str,
29
+ x2_range_input: str,
30
+ x_selection_method: str,
31
+ sigma: float,
32
+ nsample: int,
33
+ csv_file: str,
34
+ has_header: bool,
35
+ x1_col: int,
36
+ x2_col: int,
37
+ y_col: int,
38
+ ) -> None:
39
+ if dataset_type == "Generate":
40
+ try:
41
+ parsed_function = sympify(function)
42
+ except Exception as e:
43
+ raise ValueError(f"Invalid function {e}")
44
+
45
+ x1_range = self._parse_range(x1_range_input)
46
+ x2_range = self._parse_range(x2_range_input)
47
+
48
+ method = x_selection_method.lower()
49
+ if method not in ("grid", "random"):
50
+ raise ValueError(f"Invalid x_selection_method: {x_selection_method}")
51
+
52
+ self.dataset = generate_dataset(
53
+ parsed_function,
54
+ x1_range,
55
+ x2_range,
56
+ DataGenerationOptions(method, int(nsample), float(sigma)),
57
+ )
58
+
59
+ elif dataset_type == "CSV":
60
+ csv_path = self._resolve_csv_path(csv_file)
61
+ try:
62
+ self.dataset = load_dataset_from_csv(
63
+ csv_path,
64
+ bool(has_header),
65
+ int(x1_col),
66
+ int(x2_col),
67
+ int(y_col),
68
+ )
69
+ except Exception as e:
70
+ raise ValueError(f"Failed to load dataset from CSV: {e}")
71
+
72
+ else:
73
+ raise ValueError(f"Invalid dataset_type: {dataset_type}")
74
+
75
+
76
+ def compute_plots_data(
77
+ self,
78
+ loss_type: str,
79
+ regularizer_type: str,
80
+ reg_levels_input: str,
81
+ w1_range_input: str,
82
+ w2_range_input: str,
83
+ resolution: int,
84
+ ) -> None:
85
+ if self.dataset is None:
86
+ raise ValueError("Dataset is not initialized")
87
+
88
+ if loss_type not in ("l1", "l2"):
89
+ raise ValueError(f"Invalid loss_type: {loss_type}")
90
+ if regularizer_type not in ("l1", "l2"):
91
+ raise ValueError(f"Invalid regularizer_type: {regularizer_type}")
92
+
93
+ reg_levels = self._parse_levels(reg_levels_input)
94
+ w1_range = self._parse_range(w1_range_input)
95
+ w2_range = self._parse_range(w2_range_input)
96
+
97
+ self.plots_data = compute_plot_values(
98
+ self.dataset,
99
+ loss_type,
100
+ regularizer_type,
101
+ reg_levels,
102
+ w1_range,
103
+ w2_range,
104
+ int(resolution),
105
+ )
106
+
107
+ def handle_generate_plots(
108
+ self,
109
+ dataset_type: str,
110
+ function: str,
111
+ x1_range_input: str,
112
+ x2_range_input: str,
113
+ x_selection_method: str,
114
+ sigma: float,
115
+ nsample: int,
116
+ csv_file: str,
117
+ has_header: bool,
118
+ x1_col: int,
119
+ x2_col: int,
120
+ y_col: int,
121
+ loss_type: str,
122
+ regularizer_type: str,
123
+ reg_levels_input: str,
124
+ w1_range_input: str,
125
+ w2_range_input: str,
126
+ resolution: int,
127
+ ) -> tuple[Manager, Figure, Figure, Figure]:
128
+ self.update_dataset(
129
+ dataset_type,
130
+ function,
131
+ x1_range_input,
132
+ x2_range_input,
133
+ x_selection_method,
134
+ sigma,
135
+ nsample,
136
+ csv_file,
137
+ has_header,
138
+ x1_col,
139
+ x2_col,
140
+ y_col,
141
+ )
142
+ self.compute_plots_data(
143
+ loss_type,
144
+ regularizer_type,
145
+ reg_levels_input,
146
+ w1_range_input,
147
+ w2_range_input,
148
+ resolution,
149
+ )
150
+
151
+ if self.dataset is None or self.plots_data is None:
152
+ raise ValueError("Failed to generate plot data")
153
+
154
+ contour_plot = self._generate_contour_plot(self.plots_data)
155
+ data_plot = self._generate_data_plot(self.dataset)
156
+ strength_plot = self._generate_strength_plot(self.plots_data.path)
157
+ return self, contour_plot, data_plot, strength_plot
158
+
159
+ @staticmethod
160
+ def _generate_contour_plot(plots_data: PlotsData) -> Figure:
161
+ fig, ax = plt.subplots(figsize=(8, 8))
162
+ ax.set_xlabel("w1")
163
+ ax.set_ylabel("w2")
164
+
165
+ cmap = plt.get_cmap("viridis")
166
+ n_levels = len(plots_data.reg_levels)
167
+ if n_levels == 1:
168
+ colors = [cmap(0.5)]
169
+ else:
170
+ colors = [cmap(i / (n_levels - 1)) for i in range(n_levels)]
171
+
172
+ cs1 = ax.contour(
173
+ plots_data.W1,
174
+ plots_data.W2,
175
+ plots_data.norms,
176
+ levels=plots_data.reg_levels,
177
+ colors=colors,
178
+ linestyles="dashed",
179
+ )
180
+ ax.clabel(cs1, inline=True, fontsize=8)
181
+
182
+ cs2 = ax.contour(
183
+ plots_data.W1,
184
+ plots_data.W2,
185
+ plots_data.loss_values,
186
+ levels=plots_data.loss_levels,
187
+ colors=colors[::-1],
188
+ )
189
+ ax.clabel(cs2, inline=True, fontsize=8)
190
+
191
+ if plots_data.unreg_solution.ndim == 1:
192
+ ax.plot(
193
+ plots_data.unreg_solution[0],
194
+ plots_data.unreg_solution[1],
195
+ "bx",
196
+ markersize=5,
197
+ label="unregularized solution",
198
+ )
199
+ else:
200
+ ax.plot(
201
+ plots_data.unreg_solution[:, 0],
202
+ plots_data.unreg_solution[:, 1],
203
+ "b-",
204
+ label="unregularized solution",
205
+ )
206
+
207
+ ax.plot(plots_data.path[:, 0], plots_data.path[:, 1], "r-", label="regularization path")
208
+
209
+ handles = [
210
+ mlines.Line2D([], [], color="black", linestyle="-", label="loss"),
211
+ mlines.Line2D([], [], color="black", linestyle="--", label="regularization"),
212
+ mlines.Line2D([], [], color="red", linestyle="-", label="regularization path"),
213
+ ]
214
+ if plots_data.unreg_solution.ndim == 1:
215
+ handles.append(
216
+ mlines.Line2D([], [], color="blue", marker="x", linestyle="None", label="unregularized solution")
217
+ )
218
+ else:
219
+ handles.append(mlines.Line2D([], [], color="blue", linestyle="-", label="unregularized solution"))
220
+
221
+ ax.legend(handles=handles)
222
+ ax.grid(True)
223
+ return fig
224
+
225
+ @staticmethod
226
+ def _generate_data_plot(dataset: Dataset) -> Figure:
227
+ fig, ax = plt.subplots(figsize=(8, 8))
228
+ ax.set_xlabel("x1")
229
+ ax.set_ylabel("x2")
230
+
231
+ scatter = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap="viridis")
232
+ ax.grid(True)
233
+ fig.colorbar(scatter, ax=ax)
234
+ return fig
235
+
236
+ @staticmethod
237
+ def _generate_strength_plot(path: np.ndarray) -> Figure:
238
+ reg_levels = np.logspace(-4, 4, path.shape[0])
239
+
240
+ fig, ax = plt.subplots(figsize=(8, 6))
241
+ ax.set_xlabel("Regularization Strength")
242
+ ax.set_ylabel("Weight")
243
+
244
+ ax.plot(reg_levels, path[:, 0], "r-", label="w1")
245
+ ax.plot(reg_levels, path[:, 1], "b-", label="w2")
246
+ ax.set_xscale("log")
247
+ ax.legend()
248
+ ax.grid(True)
249
+ return fig
250
+
251
+ @staticmethod
252
+ def _parse_range(range_input: str) -> tuple[float, float]:
253
+ values = tuple(float(x.strip()) for x in range_input.split(","))
254
+ if len(values) != 2:
255
+ raise ValueError("Range must contain exactly two comma-separated values")
256
+ return values
257
+
258
+ @staticmethod
259
+ def _parse_levels(levels_input: str) -> list[float]:
260
+ values = [float(x.strip()) for x in levels_input.split(",")]
261
+ if not values:
262
+ raise ValueError("At least one regularization level is required")
263
+ return values
264
+
265
+ @staticmethod
266
+ def _resolve_csv_path(csv_file: str) -> str:
267
+ if csv_file is None:
268
+ raise ValueError("CSV file is required")
269
+ if isinstance(csv_file, str):
270
+ return csv_file
271
+ if isinstance(csv_file, dict) and "name" in csv_file:
272
+ return csv_file["name"]
273
+ if hasattr(csv_file, "name"):
274
+ return csv_file.name
275
+ raise ValueError("Unsupported CSV file input")
frontends/gradio/__pycache__/main.cpython-314.pyc ADDED
Binary file (11.4 kB). View file
 
frontends/gradio/main.py CHANGED
@@ -1,28 +1,17 @@
1
- import io
2
  from typing import Literal
3
 
4
  import gradio as gr
5
  from matplotlib.figure import Figure
6
- import matplotlib.pyplot as plt
7
- import matplotlib.lines as mlines
8
- import numpy as np
9
- from sympy import sympify
10
 
11
  import sys
12
  from pathlib import Path
 
13
  root_dir = Path(__file__).resolve().parent.parent.parent
14
  backend_src = root_dir / "backend" / "src"
15
  if str(backend_src) not in sys.path:
16
  sys.path.append(str(backend_src))
17
 
18
- from backend import (
19
- compute_plot_values,
20
- generate_dataset,
21
- load_dataset_from_csv,
22
- Dataset,
23
- DataGenerationOptions,
24
- PlotsData,
25
- )
26
 
27
  CSS = """
28
  .hidden-button {
@@ -31,7 +20,38 @@ CSS = """
31
  """
32
 
33
 
34
- def get_dataset(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  dataset_type: str,
36
  function: str,
37
  x1_range_input: str,
@@ -44,232 +64,35 @@ def get_dataset(
44
  x1_col: int,
45
  x2_col: int,
46
  y_col: int,
47
- ) -> Dataset:
48
- if dataset_type == "Generate":
49
- try:
50
- function = sympify(function)
51
- except Exception as e:
52
- raise ValueError(f"Invalid function: {e}")
53
-
54
- x1_range = tuple(float(x.strip()) for x in x1_range_input.split(","))
55
- x2_range = tuple(float(x.strip()) for x in x2_range_input.split(","))
56
-
57
- if (len(x1_range) != 2 or len(x2_range) != 2):
58
- raise ValueError("x1_range and x2_range must be tuples of length 2")
59
-
60
- x_selection_method = x_selection_method.lower()
61
- if x_selection_method not in ("grid", "random"):
62
- raise ValueError(f"Invalid x_selection_method: {x_selection_method}")
63
-
64
- dataset = generate_dataset(
65
- function,
66
- x1_range,
67
- x2_range,
68
- DataGenerationOptions(
69
- x_selection_method,
70
- nsample,
71
- sigma,
72
- )
73
- )
74
-
75
- elif dataset_type == "CSV":
76
- try:
77
- dataset = load_dataset_from_csv(
78
- csv_file,
79
- has_header,
80
- x1_col,
81
- x2_col,
82
- y_col,
83
- )
84
- except Exception as e:
85
- gr.Info(f"Error loading CSV: {e}")
86
- raise e
87
-
88
- else:
89
- raise ValueError(f"Invalid dataset_type: {dataset_type}")
90
-
91
- return dataset
92
-
93
-
94
- def parse_plot_settings(
95
- dataset: Dataset,
96
  loss_type: str,
97
  regularizer_type: str,
98
  reg_levels_input: str,
99
  w1_range_input: str,
100
  w2_range_input: str,
101
  resolution: int,
102
- ) -> tuple[Dataset, Literal["l1", "l2"], Literal["l1", "l2"], list[float], tuple[float, float], tuple[float, float], int]:
103
- reg_levels = [float(x.strip()) for x in reg_levels_input.split(",")]
104
- w1_range = tuple(float(x.strip()) for x in w1_range_input.split(","))
105
- w2_range = tuple(float(x.strip()) for x in w2_range_input.split(","))
106
-
107
- if loss_type not in ("l1", "l2"):
108
- raise ValueError(f"Invalid loss_type: {loss_type}")
109
- if regularizer_type not in ("l1", "l2"):
110
- raise ValueError(f"Invalid regularizer_type: {regularizer_type}")
111
-
112
- if len(w1_range) != 2 or len(w2_range) != 2:
113
- raise ValueError("w1_range and w2_range must be tuples of length 2")
114
-
115
- return (
116
- dataset,
117
  loss_type,
118
  regularizer_type,
119
- reg_levels,
120
- w1_range,
121
- w2_range,
122
  resolution,
123
  )
124
 
125
 
126
- def generate_contour_plot(
127
- W1: np.ndarray,
128
- W2: np.ndarray,
129
- losses: np.ndarray,
130
- norms: np.ndarray,
131
- loss_levels: list[float],
132
- reg_levels: list[float],
133
- unreg_solution: np.ndarray,
134
- path: np.ndarray,
135
- ) -> Figure:
136
- fig, ax = plt.subplots(figsize=(8, 8))
137
- ax.set_title("")
138
- ax.set_xlabel("w1")
139
- ax.set_ylabel("w2")
140
-
141
- cmap = plt.get_cmap("viridis")
142
- N = len(reg_levels)
143
- colors = [cmap(i / (N - 1)) for i in range(N)]
144
-
145
- # regularizer contours
146
- cs1 = ax.contour(W1, W2, norms, levels=reg_levels, colors=colors, linestyles="dashed")
147
- ax.clabel(cs1, inline=True, fontsize=8) # show contour levels
148
-
149
- # loss contours
150
- cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
151
- ax.clabel(cs2, inline=True, fontsize=8)
152
-
153
- # unregularized solution
154
- if unreg_solution.ndim == 1:
155
- ax.plot(unreg_solution[0], unreg_solution[1], "bx", markersize=5, label="unregularized solution")
156
- else:
157
- ax.plot(unreg_solution[:, 0], unreg_solution[:, 1], "b-", label="unregularized solution")
158
-
159
- ax.plot(path[:, 0], path[:, 1], "r-", label="regularization path")
160
-
161
- # legend
162
- loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
163
- reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
164
- handles = [loss_line, reg_line]
165
-
166
- path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
167
- handles.append(path_line)
168
-
169
- if unreg_solution.ndim == 1:
170
- handles.append(
171
- mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
172
- )
173
- else:
174
- handles.append(
175
- mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
176
- )
177
-
178
- ax.legend(handles=handles)
179
-
180
- ax.grid(True)
181
-
182
- return fig
183
-
184
-
185
- def generate_data_plot(dataset: Dataset) -> Figure:
186
- fig, ax = plt.subplots(figsize=(8, 8))
187
- ax.set_xlabel("x1")
188
- ax.set_ylabel("x2")
189
-
190
- sc = ax.scatter(dataset.x1, dataset.x2, c=dataset.y, cmap='viridis')
191
- ax.grid(True)
192
- fig.colorbar(sc, ax=ax)
193
-
194
- return fig
195
-
196
-
197
- def generate_strength_plot(
198
- path: np.ndarray,
199
- reg_levels: np.ndarray,
200
- ):
201
- fig, ax = plt.subplots(figsize=(8, 6))
202
- ax.set_xlabel("Regularization Strength")
203
- ax.set_ylabel("Weight")
204
-
205
- ax.plot(reg_levels, path[:, 0], 'r-', label='w1')
206
- ax.plot(reg_levels, path[:, 1], 'b-', label='w2')
207
-
208
- ax.set_xscale('log')
209
- ax.legend()
210
- ax.grid(True)
211
-
212
- return fig
213
-
214
-
215
- def generate_all_plots(
216
- dataset: Dataset,
217
- plots_data: PlotsData,
218
- ) -> tuple[Figure, Figure]:
219
- contour_plot = generate_contour_plot(
220
- plots_data.W1,
221
- plots_data.W2,
222
- plots_data.loss_values,
223
- plots_data.norms,
224
- plots_data.loss_levels,
225
- plots_data.reg_levels,
226
- plots_data.unreg_solution,
227
- plots_data.path,
228
- )
229
-
230
- data_plot = generate_data_plot(
231
- dataset
232
- )
233
-
234
- strength_plot = generate_strength_plot(
235
- plots_data.path,
236
- np.logspace(-4, 4, 100),
237
- )
238
-
239
- return contour_plot, data_plot, strength_plot
240
-
241
-
242
- def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
243
- if dataset_type == "Generate":
244
- return (
245
- gr.update(visible=True), # function
246
- gr.update(visible=True), # x1_textbox
247
- gr.update(visible=True), # x2_textbox
248
- gr.update(visible=True), # x_selection_method
249
- gr.update(visible=True), # sigma
250
- gr.update(visible=True), # nsample
251
- gr.update(visible=False), # csv_file
252
- gr.update(visible=False), # has_header
253
- gr.update(visible=False), # x1_col
254
- gr.update(visible=False), # x2_col
255
- gr.update(visible=False), # y_col
256
- )
257
- else: # CSV
258
- return (
259
- gr.update(visible=False), # function
260
- gr.update(visible=False), # x1_textbox
261
- gr.update(visible=False), # x2_textbox
262
- gr.update(visible=False), # x_selection_method
263
- gr.update(visible=False), # sigma
264
- gr.update(visible=False), # nsample
265
- gr.update(visible=True), # csv_file
266
- gr.update(visible=True), # has_header
267
- gr.update(visible=True), # x1_col
268
- gr.update(visible=True), # x2_col
269
- gr.update(visible=True), # y_col
270
- )
271
-
272
-
273
  def launch():
274
  default_dataset_type = "Generate"
275
 
@@ -293,71 +116,41 @@ def launch():
293
  default_w2_range = "-100, 100"
294
  default_resolution = 100
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  with gr.Blocks() as demo:
297
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
298
 
299
- dataset = gr.State(
300
- get_dataset(
301
- default_dataset_type,
302
- default_function,
303
- default_x1_range,
304
- default_x2_range,
305
- default_x_selection_method,
306
- default_sigma,
307
- default_num_points,
308
- default_csv_file,
309
- default_has_header,
310
- default_x1_col,
311
- default_x2_col,
312
- default_y_col,
313
- )
314
- )
315
-
316
- # wrapped PlotsData
317
- plots_data = gr.State(
318
- compute_plot_values(
319
- *parse_plot_settings(
320
- dataset.value,
321
- default_loss_type,
322
- default_regularizer_type,
323
- default_reg_levels,
324
- default_w1_range,
325
- default_w2_range,
326
- default_resolution,
327
- )
328
- )
329
- )
330
 
331
  with gr.Row():
332
  with gr.Column(scale=2):
333
  with gr.Tab("Contours"):
334
- main_plot = gr.Plot(
335
- value=generate_contour_plot(
336
- plots_data.value.W1,
337
- plots_data.value.W2,
338
- plots_data.value.loss_values,
339
- plots_data.value.norms,
340
- plots_data.value.loss_levels,
341
- plots_data.value.reg_levels,
342
- plots_data.value.unreg_solution,
343
- plots_data.value.path,
344
- )
345
- )
346
  with gr.Tab("Data"):
347
- # todo
348
- data_plot = gr.Plot(
349
- value=generate_data_plot(
350
- dataset.value
351
- )
352
- )
353
  with gr.Tab("Strength"):
354
- # todo
355
- strength_plot = gr.Plot(
356
- value=generate_strength_plot(
357
- plots_data.value.path,
358
- np.logspace(-4, 4, 100), # todo
359
- )
360
- )
361
 
362
  with gr.Column(scale=1):
363
  with gr.Tab("Data"):
@@ -371,7 +164,7 @@ def launch():
371
 
372
  with gr.Row():
373
  function = gr.Textbox(
374
- label="Function (in terms of x1 and x2)",
375
  value=default_function,
376
  interactive=True,
377
  )
@@ -401,28 +194,28 @@ def launch():
401
  interactive=True,
402
  )
403
  nsample = gr.Number(
404
- label="Number of points",
405
  value=default_num_points,
406
  interactive=True,
407
  )
408
 
409
  with gr.Row():
410
  csv_file = gr.File(
411
- label="Upload CSV file - must have columns: (x1, x2, y)",
412
- file_types=['.csv'],
413
- visible=False, # function mode is default
414
  )
415
 
416
  with gr.Row():
417
  has_header = gr.Checkbox(
418
- label="CSV has header row",
419
  value=default_has_header,
420
  visible=False,
421
  )
422
-
423
  with gr.Row():
424
  x1_col = gr.Number(
425
- label="x1 column index (0-based)",
426
  value=default_x1_col,
427
  visible=False,
428
  )
@@ -436,7 +229,7 @@ def launch():
436
  value=default_y_col,
437
  visible=False,
438
  )
439
-
440
  dataset_type.change(
441
  fn=handle_dataset_type_change,
442
  inputs=[dataset_type],
@@ -499,12 +292,9 @@ def launch():
499
  )
500
 
501
  gr.Button("Regenerate Plots").click(
502
- fn=lambda: (gr.Plot(), gr.Plot(), gr.Plot()),
503
- inputs=[],
504
- outputs=[main_plot, data_plot, strength_plot],
505
- ).then(
506
- fn=get_dataset,
507
  inputs=[
 
508
  dataset_type,
509
  function,
510
  x1_textbox,
@@ -517,12 +307,6 @@ def launch():
517
  x1_col,
518
  x2_col,
519
  y_col,
520
- ],
521
- outputs=[dataset],
522
- ).then(
523
- fn=lambda *args: compute_plot_values(*parse_plot_settings(*args)),
524
- inputs=[
525
- dataset,
526
  loss_type_dropdown,
527
  regularizer_type_dropdown,
528
  regularizer_levels_textbox,
@@ -530,11 +314,7 @@ def launch():
530
  w2_range_textbox,
531
  resolution_slider,
532
  ],
533
- outputs=[plots_data],
534
- ).then(
535
- fn=generate_all_plots,
536
- inputs=[dataset, plots_data],
537
- outputs=[main_plot, data_plot, strength_plot],
538
  )
539
 
540
  demo.launch(css=CSS)
 
 
1
  from typing import Literal
2
 
3
  import gradio as gr
4
  from matplotlib.figure import Figure
 
 
 
 
5
 
6
  import sys
7
  from pathlib import Path
8
+
9
  root_dir = Path(__file__).resolve().parent.parent.parent
10
  backend_src = root_dir / "backend" / "src"
11
  if str(backend_src) not in sys.path:
12
  sys.path.append(str(backend_src))
13
 
14
+ from manager import Manager
 
 
 
 
 
 
 
15
 
16
  CSS = """
17
  .hidden-button {
 
20
  """
21
 
22
 
23
+ def handle_dataset_type_change(dataset_type: Literal["Generate", "CSV"]):
24
+ if dataset_type == "Generate":
25
+ return (
26
+ gr.update(visible=True),
27
+ gr.update(visible=True),
28
+ gr.update(visible=True),
29
+ gr.update(visible=True),
30
+ gr.update(visible=True),
31
+ gr.update(visible=True),
32
+ gr.update(visible=False),
33
+ gr.update(visible=False),
34
+ gr.update(visible=False),
35
+ gr.update(visible=False),
36
+ gr.update(visible=False),
37
+ )
38
+ return (
39
+ gr.update(visible=False),
40
+ gr.update(visible=False),
41
+ gr.update(visible=False),
42
+ gr.update(visible=False),
43
+ gr.update(visible=False),
44
+ gr.update(visible=False),
45
+ gr.update(visible=True),
46
+ gr.update(visible=True),
47
+ gr.update(visible=True),
48
+ gr.update(visible=True),
49
+ gr.update(visible=True),
50
+ )
51
+
52
+
53
+ def handle_generate_plots(
54
+ manager: Manager,
55
  dataset_type: str,
56
  function: str,
57
  x1_range_input: str,
 
64
  x1_col: int,
65
  x2_col: int,
66
  y_col: int,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  loss_type: str,
68
  regularizer_type: str,
69
  reg_levels_input: str,
70
  w1_range_input: str,
71
  w2_range_input: str,
72
  resolution: int,
73
+ ) -> tuple[Manager, Figure, Figure, Figure]:
74
+ return manager.handle_generate_plots(
75
+ dataset_type,
76
+ function,
77
+ x1_range_input,
78
+ x2_range_input,
79
+ x_selection_method,
80
+ sigma,
81
+ nsample,
82
+ csv_file,
83
+ has_header,
84
+ x1_col,
85
+ x2_col,
86
+ y_col,
 
87
  loss_type,
88
  regularizer_type,
89
+ reg_levels_input,
90
+ w1_range_input,
91
+ w2_range_input,
92
  resolution,
93
  )
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def launch():
97
  default_dataset_type = "Generate"
98
 
 
116
  default_w2_range = "-100, 100"
117
  default_resolution = 100
118
 
119
+ manager = Manager()
120
+ manager, default_contour_plot, default_data_plot, default_strength_plot = manager.handle_generate_plots(
121
+ default_dataset_type,
122
+ default_function,
123
+ default_x1_range,
124
+ default_x2_range,
125
+ default_x_selection_method,
126
+ default_sigma,
127
+ default_num_points,
128
+ default_csv_file,
129
+ default_has_header,
130
+ default_x1_col,
131
+ default_x2_col,
132
+ default_y_col,
133
+ default_loss_type,
134
+ default_regularizer_type,
135
+ default_reg_levels,
136
+ default_w1_range,
137
+ default_w2_range,
138
+ default_resolution,
139
+ )
140
+
141
  with gr.Blocks() as demo:
142
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")
143
 
144
+ manager_state = gr.State(manager)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  with gr.Row():
147
  with gr.Column(scale=2):
148
  with gr.Tab("Contours"):
149
+ main_plot = gr.Plot(value=default_contour_plot)
 
 
 
 
 
 
 
 
 
 
 
150
  with gr.Tab("Data"):
151
+ data_plot = gr.Plot(value=default_data_plot)
 
 
 
 
 
152
  with gr.Tab("Strength"):
153
+ strength_plot = gr.Plot(value=default_strength_plot)
 
 
 
 
 
 
154
 
155
  with gr.Column(scale=1):
156
  with gr.Tab("Data"):
 
164
 
165
  with gr.Row():
166
  function = gr.Textbox(
167
+ label="Function (in terms of x1 and x2)",
168
  value=default_function,
169
  interactive=True,
170
  )
 
194
  interactive=True,
195
  )
196
  nsample = gr.Number(
197
+ label="Number of points",
198
  value=default_num_points,
199
  interactive=True,
200
  )
201
 
202
  with gr.Row():
203
  csv_file = gr.File(
204
+ label="Upload CSV file - must have columns: (x1, x2, y)",
205
+ file_types=[".csv"],
206
+ visible=False,
207
  )
208
 
209
  with gr.Row():
210
  has_header = gr.Checkbox(
211
+ label="CSV has header row",
212
  value=default_has_header,
213
  visible=False,
214
  )
215
+
216
  with gr.Row():
217
  x1_col = gr.Number(
218
+ label="x1 column index (0-based)",
219
  value=default_x1_col,
220
  visible=False,
221
  )
 
229
  value=default_y_col,
230
  visible=False,
231
  )
232
+
233
  dataset_type.change(
234
  fn=handle_dataset_type_change,
235
  inputs=[dataset_type],
 
292
  )
293
 
294
  gr.Button("Regenerate Plots").click(
295
+ fn=handle_generate_plots,
 
 
 
 
296
  inputs=[
297
+ manager_state,
298
  dataset_type,
299
  function,
300
  x1_textbox,
 
307
  x1_col,
308
  x2_col,
309
  y_col,
 
 
 
 
 
 
310
  loss_type_dropdown,
311
  regularizer_type_dropdown,
312
  regularizer_levels_textbox,
 
314
  w2_range_textbox,
315
  resolution_slider,
316
  ],
317
+ outputs=[manager_state, main_plot, data_plot, strength_plot],
 
 
 
 
318
  )
319
 
320
  demo.launch(css=CSS)