Joel Woodfield commited on
Commit
85ac76b
·
1 Parent(s): 484caec

Add automatic settings generation

Browse files
backend/src/__pycache__/logic.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/logic.cpython-312.pyc and b/backend/src/__pycache__/logic.cpython-312.pyc differ
 
backend/src/__pycache__/manager.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/manager.cpython-312.pyc and b/backend/src/__pycache__/manager.cpython-312.pyc differ
 
backend/src/logic.py CHANGED
@@ -260,4 +260,27 @@ def compute_plot_values(
260
  reg_levels=reg_levels,
261
  unreg_solution=unreg_solution,
262
  path=path,
263
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  reg_levels=reg_levels,
261
  unreg_solution=unreg_solution,
262
  path=path,
263
+ )
264
+
265
+
266
+ def compute_suggested_settings(
267
+ dataset: Dataset
268
+ ) -> tuple[tuple[float, float], tuple[float, float], list[float]]:
269
+ x = np.stack([dataset.x1, dataset.x2], axis=1)
270
+ moore_penrose = np.linalg.pinv(x) @ np.array(dataset.y)
271
+
272
+ if np.isclose(moore_penrose, 0).all():
273
+ w1_range = (-10, 10)
274
+ w2_range = (-10, 10)
275
+ return w1_range, w2_range, []
276
+
277
+ width = np.max(np.abs(moore_penrose)) * 3
278
+
279
+ w1_range = (-width, width)
280
+ w2_range = (-width, width)
281
+
282
+ opt_norm = float(np.linalg.norm(moore_penrose, ord=2))
283
+
284
+ reg_levels = [i / 4 * opt_norm for i in range(3)]
285
+
286
+ return w1_range, w2_range, reg_levels
backend/src/manager.py CHANGED
@@ -11,6 +11,7 @@ from logic import (
11
  Dataset,
12
  PlotsData,
13
  compute_plot_values,
 
14
  generate_dataset,
15
  load_dataset_from_csv,
16
  )
@@ -156,6 +157,18 @@ class Manager:
156
  strength_plot = self._generate_strength_plot(self.plots_data.path)
157
  return self, contour_plot, data_plot, strength_plot
158
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  @staticmethod
160
  def _generate_contour_plot(plots_data: PlotsData) -> Figure:
161
  fig, ax = plt.subplots(figsize=(8, 8))
 
11
  Dataset,
12
  PlotsData,
13
  compute_plot_values,
14
+ compute_suggested_settings,
15
  generate_dataset,
16
  load_dataset_from_csv,
17
  )
 
157
  strength_plot = self._generate_strength_plot(self.plots_data.path)
158
  return self, contour_plot, data_plot, strength_plot
159
 
160
+ def handle_use_suggested_settings(self) -> tuple[Manager, str, str, str]:
161
+ if self.dataset is None:
162
+ raise ValueError("Dataset is not initialized")
163
+
164
+ w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
165
+
166
+ w1_range_input = f"{w1_range[0]:.2f}, {w1_range[1]:.2f}"
167
+ w2_range_input = f"{w2_range[0]:.2f}, {w2_range[1]:.2f}"
168
+ reg_levels_input = ", ".join(f"{level:.2f}" for level in reg_levels)
169
+
170
+ return self, w1_range_input, w2_range_input, reg_levels_input
171
+
172
  @staticmethod
173
  def _generate_contour_plot(plots_data: PlotsData) -> Figure:
174
  fig, ax = plt.subplots(figsize=(8, 8))
frontends/gradio/main.py CHANGED
@@ -92,6 +92,11 @@ def handle_generate_plots(
92
  resolution,
93
  )
94
 
 
 
 
 
 
95
 
96
  def launch():
97
  default_dataset_type = "Generate"
@@ -256,14 +261,14 @@ def launch():
256
  value=default_loss_type,
257
  interactive=True,
258
  )
259
-
260
- with gr.Row():
261
  regularizer_type_dropdown = gr.Dropdown(
262
  label="Regularizer type",
263
  choices=["l1", "l2"],
264
  value=default_regularizer_type,
265
  interactive=True,
266
  )
 
 
267
  regularizer_levels_textbox = gr.Textbox(
268
  label="Regularization levels (comma-separated)",
269
  value=default_reg_levels,
@@ -290,6 +295,18 @@ def launch():
290
  step=1,
291
  interactive=True,
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  gr.Button("Regenerate Plots").click(
295
  fn=handle_generate_plots,
 
92
  resolution,
93
  )
94
 
95
+ def handle_use_suggested_settings(
96
+ manager: Manager,
97
+ ) -> tuple[Manager, str, str, str]:
98
+ return manager.handle_use_suggested_settings()
99
+
100
 
101
  def launch():
102
  default_dataset_type = "Generate"
 
261
  value=default_loss_type,
262
  interactive=True,
263
  )
 
 
264
  regularizer_type_dropdown = gr.Dropdown(
265
  label="Regularizer type",
266
  choices=["l1", "l2"],
267
  value=default_regularizer_type,
268
  interactive=True,
269
  )
270
+
271
+ with gr.Row():
272
  regularizer_levels_textbox = gr.Textbox(
273
  label="Regularization levels (comma-separated)",
274
  value=default_reg_levels,
 
295
  step=1,
296
  interactive=True,
297
  )
298
+
299
+ gr.Button("Use suggested settings").click(
300
+ fn=handle_use_suggested_settings,
301
+ inputs=[manager_state],
302
+ outputs=[
303
+ manager_state,
304
+ w1_range_textbox,
305
+ w2_range_textbox,
306
+ regularizer_levels_textbox,
307
+ ],
308
+ )
309
+
310
 
311
  gr.Button("Regenerate Plots").click(
312
  fn=handle_generate_plots,