Joel Woodfield commited on
Commit
c5b45d7
·
1 Parent(s): b1278c1

Use current dataset settings for setting suggested settings

Browse files
backend/src/__pycache__/logic.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/logic.cpython-312.pyc and b/backend/src/__pycache__/logic.cpython-312.pyc differ
 
backend/src/__pycache__/manager.cpython-312.pyc CHANGED
Binary files a/backend/src/__pycache__/manager.cpython-312.pyc and b/backend/src/__pycache__/manager.cpython-312.pyc differ
 
backend/src/logic.py CHANGED
@@ -281,6 +281,6 @@ def compute_suggested_settings(
281
 
282
  opt_norm = float(np.linalg.norm(moore_penrose, ord=2))
283
 
284
- reg_levels = [i / 4 * opt_norm for i in range(3)]
285
 
286
  return w1_range, w2_range, reg_levels
 
281
 
282
  opt_norm = float(np.linalg.norm(moore_penrose, ord=2))
283
 
284
+ reg_levels = [i / 4 * opt_norm for i in range(1, 4)]
285
 
286
  return w1_range, w2_range, reg_levels
backend/src/manager.py CHANGED
@@ -37,6 +37,36 @@ class Manager:
37
  x2_col: int,
38
  y_col: int,
39
  ) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if dataset_type == "Generate":
41
  try:
42
  parsed_function = sympify(function)
@@ -50,7 +80,7 @@ class Manager:
50
  if method not in ("grid", "random"):
51
  raise ValueError(f"Invalid x_selection_method: {x_selection_method}")
52
 
53
- self.dataset = generate_dataset(
54
  parsed_function,
55
  x1_range,
56
  x2_range,
@@ -60,7 +90,7 @@ class Manager:
60
  elif dataset_type == "CSV":
61
  csv_path = self._resolve_csv_path(csv_file)
62
  try:
63
- self.dataset = load_dataset_from_csv(
64
  csv_path,
65
  bool(has_header),
66
  int(x1_col),
@@ -157,11 +187,37 @@ class Manager:
157
  strength_plot = self._generate_strength_plot(self.plots_data.path)
158
  return self, contour_plot, data_plot, strength_plot
159
 
160
- def handle_use_suggested_settings(self) -> tuple[Manager, str, str, str]:
161
- if self.dataset is None:
162
- raise ValueError("Dataset is not initialized")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
165
 
166
  w1_range_input = f"{w1_range[0]:.2f}, {w1_range[1]:.2f}"
167
  w2_range_input = f"{w2_range[0]:.2f}, {w2_range[1]:.2f}"
 
37
  x2_col: int,
38
  y_col: int,
39
  ) -> None:
40
+ self.dataset = self._compute_dataset(
41
+ dataset_type,
42
+ function,
43
+ x1_range_input,
44
+ x2_range_input,
45
+ x_selection_method,
46
+ sigma,
47
+ nsample,
48
+ csv_file,
49
+ has_header,
50
+ x1_col,
51
+ x2_col,
52
+ y_col,
53
+ )
54
+
55
+ def _compute_dataset(
56
+ self,
57
+ dataset_type: str,
58
+ function: str,
59
+ x1_range_input: str,
60
+ x2_range_input: str,
61
+ x_selection_method: str,
62
+ sigma: float,
63
+ nsample: int,
64
+ csv_file: str,
65
+ has_header: bool,
66
+ x1_col: int,
67
+ x2_col: int,
68
+ y_col: int,
69
+ ) -> Dataset:
70
  if dataset_type == "Generate":
71
  try:
72
  parsed_function = sympify(function)
 
80
  if method not in ("grid", "random"):
81
  raise ValueError(f"Invalid x_selection_method: {x_selection_method}")
82
 
83
+ return generate_dataset(
84
  parsed_function,
85
  x1_range,
86
  x2_range,
 
90
  elif dataset_type == "CSV":
91
  csv_path = self._resolve_csv_path(csv_file)
92
  try:
93
+ return load_dataset_from_csv(
94
  csv_path,
95
  bool(has_header),
96
  int(x1_col),
 
187
  strength_plot = self._generate_strength_plot(self.plots_data.path)
188
  return self, contour_plot, data_plot, strength_plot
189
 
190
+ def handle_use_suggested_settings(
191
+ self,
192
+ dataset_type: str,
193
+ function: str,
194
+ x1_range_input: str,
195
+ x2_range_input: str,
196
+ x_selection_method: str,
197
+ sigma: float,
198
+ nsample: int,
199
+ csv_file: str,
200
+ has_header: bool,
201
+ x1_col: int,
202
+ x2_col: int,
203
+ y_col: int,
204
+ ) -> tuple[Manager, str, str, str]:
205
+ dataset = self._compute_dataset(
206
+ dataset_type,
207
+ function,
208
+ x1_range_input,
209
+ x2_range_input,
210
+ x_selection_method,
211
+ sigma,
212
+ nsample,
213
+ csv_file,
214
+ has_header,
215
+ x1_col,
216
+ x2_col,
217
+ y_col,
218
+ )
219
 
220
+ w1_range, w2_range, reg_levels = compute_suggested_settings(dataset)
221
 
222
  w1_range_input = f"{w1_range[0]:.2f}, {w1_range[1]:.2f}"
223
  w2_range_input = f"{w2_range[0]:.2f}, {w2_range[1]:.2f}"
frontends/gradio/main.py CHANGED
@@ -94,8 +94,33 @@ def handle_generate_plots(
94
 
95
  def handle_use_suggested_settings(
96
  manager: Manager,
 
 
 
 
 
 
 
 
 
 
 
 
97
  ) -> tuple[Manager, str, str, str]:
98
- return manager.handle_use_suggested_settings()
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
 
101
  def launch():
@@ -307,7 +332,21 @@ def launch():
307
 
308
  gr.Button("Use suggested settings").click(
309
  fn=handle_use_suggested_settings,
310
- inputs=[manager_state],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  outputs=[
312
  manager_state,
313
  w1_range_textbox,
 
94
 
95
  def handle_use_suggested_settings(
96
  manager: Manager,
97
+ dataset_type: str,
98
+ function: str,
99
+ x1_range_input: str,
100
+ x2_range_input: str,
101
+ x_selection_method: str,
102
+ sigma: float,
103
+ nsample: int,
104
+ csv_file: str,
105
+ has_header: bool,
106
+ x1_col: int,
107
+ x2_col: int,
108
+ y_col: int,
109
  ) -> tuple[Manager, str, str, str]:
110
+ return manager.handle_use_suggested_settings(
111
+ dataset_type,
112
+ function,
113
+ x1_range_input,
114
+ x2_range_input,
115
+ x_selection_method,
116
+ sigma,
117
+ nsample,
118
+ csv_file,
119
+ has_header,
120
+ x1_col,
121
+ x2_col,
122
+ y_col,
123
+ )
124
 
125
 
126
  def launch():
 
332
 
333
  gr.Button("Use suggested settings").click(
334
  fn=handle_use_suggested_settings,
335
+ inputs=[
336
+ manager_state,
337
+ dataset_type,
338
+ function,
339
+ x1_textbox,
340
+ x2_textbox,
341
+ x_selection_method,
342
+ sigma,
343
+ nsample,
344
+ csv_file,
345
+ has_header,
346
+ x1_col,
347
+ x2_col,
348
+ y_col,
349
+ ],
350
  outputs=[
351
  manager_state,
352
  w1_range_textbox,