joel-woodfield commited on
Commit
19aa7d3
·
1 Parent(s): 13fd633

Automatically set the suggested settings

Browse files
backend/src/__pycache__/logic.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/logic.cpython-312.pyc and b/backend/src/__pycache__/logic.cpython-312.pyc differ
 
backend/src/__pycache__/manager.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/manager.cpython-312.pyc and b/backend/src/__pycache__/manager.cpython-312.pyc differ
 
backend/src/manager.py CHANGED
@@ -152,9 +152,6 @@ class Manager:
152
  self,
153
  loss_type: str,
154
  regularizer_type: str,
155
- reg_levels_input: str,
156
- w1_range_input: str,
157
- w2_range_input: str,
158
  resolution: int,
159
  ) -> None:
160
  if self.dataset is None:
@@ -165,19 +162,7 @@ class Manager:
165
  if regularizer_type not in ("l1", "l2"):
166
  raise ValueError(f"Invalid regularizer_type: {regularizer_type}")
167
 
168
- try:
169
- reg_levels = self._parse_levels(reg_levels_input)
170
- except Exception as e:
171
- raise ValueError(f"Invalid regularization levels: {e}")
172
-
173
- try:
174
- w1_range = self._parse_range(w1_range_input)
175
- except Exception as e:
176
- raise ValueError(f"Invalid w1 range: {e}")
177
- try:
178
- w2_range = self._parse_range(w2_range_input)
179
- except Exception as e:
180
- raise ValueError(f"Invalid w2 range: {e}")
181
 
182
  self.plots_data = compute_plot_values(
183
  self.dataset,
@@ -205,9 +190,6 @@ class Manager:
205
  y_col: int,
206
  loss_type: str,
207
  regularizer_type: str,
208
- reg_levels_input: str,
209
- w1_range_input: str,
210
- w2_range_input: str,
211
  resolution: int,
212
  ) -> tuple[Manager, Figure, Figure, Figure]:
213
  self.update_dataset(
@@ -224,12 +206,10 @@ class Manager:
224
  x2_col,
225
  y_col,
226
  )
 
227
  self.compute_plots_data(
228
  loss_type,
229
  regularizer_type,
230
- reg_levels_input,
231
- w1_range_input,
232
- w2_range_input,
233
  resolution,
234
  )
235
 
@@ -241,44 +221,6 @@ class Manager:
241
  strength_plot = self._generate_strength_plot(self.plots_data.path)
242
  return self, contour_plot, data_plot, strength_plot
243
 
244
- def handle_use_suggested_settings(
245
- self,
246
- dataset_type: str,
247
- function: str,
248
- x1_range_input: str,
249
- x2_range_input: str,
250
- x_selection_method: str,
251
- sigma: float,
252
- nsample: int,
253
- csv_file: str,
254
- has_header: bool,
255
- x1_col: int,
256
- x2_col: int,
257
- y_col: int,
258
- ) -> tuple[Manager, str, str, str]:
259
- dataset = self._compute_dataset(
260
- dataset_type,
261
- function,
262
- x1_range_input,
263
- x2_range_input,
264
- x_selection_method,
265
- sigma,
266
- nsample,
267
- csv_file,
268
- has_header,
269
- x1_col,
270
- x2_col,
271
- y_col,
272
- )
273
-
274
- w1_range, w2_range, reg_levels = compute_suggested_settings(dataset)
275
-
276
- w1_range_input = f"{w1_range[0]:.2f}, {w1_range[1]:.2f}"
277
- w2_range_input = f"{w2_range[0]:.2f}, {w2_range[1]:.2f}"
278
- reg_levels_input = ", ".join(f"{level:.2f}" for level in reg_levels)
279
-
280
- return self, w1_range_input, w2_range_input, reg_levels_input
281
-
282
  @staticmethod
283
  def _generate_contour_plot(plots_data: PlotsData) -> Figure:
284
  fig, ax = plt.subplots(figsize=(8, 8))
 
152
  self,
153
  loss_type: str,
154
  regularizer_type: str,
 
 
 
155
  resolution: int,
156
  ) -> None:
157
  if self.dataset is None:
 
162
  if regularizer_type not in ("l1", "l2"):
163
  raise ValueError(f"Invalid regularizer_type: {regularizer_type}")
164
 
165
+ w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  self.plots_data = compute_plot_values(
168
  self.dataset,
 
190
  y_col: int,
191
  loss_type: str,
192
  regularizer_type: str,
 
 
 
193
  resolution: int,
194
  ) -> tuple[Manager, Figure, Figure, Figure]:
195
  self.update_dataset(
 
206
  x2_col,
207
  y_col,
208
  )
209
+
210
  self.compute_plots_data(
211
  loss_type,
212
  regularizer_type,
 
 
 
213
  resolution,
214
  )
215
 
 
221
  strength_plot = self._generate_strength_plot(self.plots_data.path)
222
  return self, contour_plot, data_plot, strength_plot
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  @staticmethod
225
  def _generate_contour_plot(plots_data: PlotsData) -> Figure:
226
  fig, ax = plt.subplots(figsize=(8, 8))
frontends/gradio/main.py CHANGED
@@ -66,9 +66,6 @@ def handle_generate_plots(
66
  y_col: int,
67
  loss_type: str,
68
  regularizer_type: str,
69
- reg_levels_input: str,
70
- w1_range_input: str,
71
- w2_range_input: str,
72
  resolution: int,
73
  ) -> tuple[Manager, Figure, Figure, Figure]:
74
  try:
@@ -87,49 +84,12 @@ def handle_generate_plots(
87
  y_col,
88
  loss_type,
89
  regularizer_type,
90
- reg_levels_input,
91
- w1_range_input,
92
- w2_range_input,
93
  resolution,
94
  )
95
  except Exception as e:
96
  raise gr.Error("Error generating plots: " + str(e))
97
 
98
 
99
- def handle_use_suggested_settings(
100
- manager: Manager,
101
- dataset_type: str,
102
- function: str,
103
- x1_range_input: str,
104
- x2_range_input: str,
105
- x_selection_method: str,
106
- sigma: float,
107
- nsample: int,
108
- csv_file: str,
109
- has_header: bool,
110
- x1_col: int,
111
- x2_col: int,
112
- y_col: int,
113
- ) -> tuple[Manager, str, str, str]:
114
- try:
115
- return manager.handle_use_suggested_settings(
116
- dataset_type,
117
- function,
118
- x1_range_input,
119
- x2_range_input,
120
- x_selection_method,
121
- sigma,
122
- nsample,
123
- csv_file,
124
- has_header,
125
- x1_col,
126
- x2_col,
127
- y_col,
128
- )
129
- except Exception as e:
130
- raise gr.Error("Error computing suggested settings: " + str(e))
131
-
132
-
133
  def launch():
134
  default_dataset_type = "Generate"
135
 
@@ -148,9 +108,6 @@ def launch():
148
 
149
  default_loss_type = "l2"
150
  default_regularizer_type = "l2"
151
- default_reg_levels = "10, 20, 30"
152
- default_w1_range = "-100, 100"
153
- default_w2_range = "-100, 100"
154
  default_resolution = 100
155
 
156
  manager = Manager()
@@ -169,9 +126,6 @@ def launch():
169
  default_y_col,
170
  default_loss_type,
171
  default_regularizer_type,
172
- default_reg_levels,
173
- default_w1_range,
174
- default_w2_range,
175
  default_resolution,
176
  )
177
 
@@ -309,25 +263,6 @@ def launch():
309
  interactive=True,
310
  )
311
 
312
- with gr.Row():
313
- regularizer_levels_textbox = gr.Textbox(
314
- label="Regularization levels (comma-separated)",
315
- value=default_reg_levels,
316
- interactive=True,
317
- )
318
-
319
- with gr.Row():
320
- w1_range_textbox = gr.Textbox(
321
- label="w1 range (min,max)",
322
- value=default_w1_range,
323
- interactive=True,
324
- )
325
- w2_range_textbox = gr.Textbox(
326
- label="w2 range (min,max)",
327
- value=default_w2_range,
328
- interactive=True,
329
- )
330
-
331
  resolution_slider = gr.Slider(
332
  label="Grid resolution",
333
  value=default_resolution,
@@ -337,32 +272,6 @@ def launch():
337
  interactive=True,
338
  )
339
 
340
- gr.Button("Use suggested settings").click(
341
- fn=handle_use_suggested_settings,
342
- inputs=[
343
- manager_state,
344
- dataset_type,
345
- function,
346
- x1_textbox,
347
- x2_textbox,
348
- x_selection_method,
349
- sigma,
350
- nsample,
351
- csv_file,
352
- has_header,
353
- x1_col,
354
- x2_col,
355
- y_col,
356
- ],
357
- outputs=[
358
- manager_state,
359
- w1_range_textbox,
360
- w2_range_textbox,
361
- regularizer_levels_textbox,
362
- ],
363
- )
364
-
365
-
366
  gr.Button("Regenerate Plots").click(
367
  fn=handle_generate_plots,
368
  inputs=[
@@ -381,9 +290,6 @@ def launch():
381
  y_col,
382
  loss_type_dropdown,
383
  regularizer_type_dropdown,
384
- regularizer_levels_textbox,
385
- w1_range_textbox,
386
- w2_range_textbox,
387
  resolution_slider,
388
  ],
389
  outputs=[manager_state, main_plot, data_plot, strength_plot],
 
66
  y_col: int,
67
  loss_type: str,
68
  regularizer_type: str,
 
 
 
69
  resolution: int,
70
  ) -> tuple[Manager, Figure, Figure, Figure]:
71
  try:
 
84
  y_col,
85
  loss_type,
86
  regularizer_type,
 
 
 
87
  resolution,
88
  )
89
  except Exception as e:
90
  raise gr.Error("Error generating plots: " + str(e))
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def launch():
94
  default_dataset_type = "Generate"
95
 
 
108
 
109
  default_loss_type = "l2"
110
  default_regularizer_type = "l2"
 
 
 
111
  default_resolution = 100
112
 
113
  manager = Manager()
 
126
  default_y_col,
127
  default_loss_type,
128
  default_regularizer_type,
 
 
 
129
  default_resolution,
130
  )
131
 
 
263
  interactive=True,
264
  )
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  resolution_slider = gr.Slider(
267
  label="Grid resolution",
268
  value=default_resolution,
 
272
  interactive=True,
273
  )
274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
  gr.Button("Regenerate Plots").click(
276
  fn=handle_generate_plots,
277
  inputs=[
 
290
  y_col,
291
  loss_type_dropdown,
292
  regularizer_type_dropdown,
 
 
 
293
  resolution_slider,
294
  ],
295
  outputs=[manager_state, main_plot, data_plot, strength_plot],