Spaces:
Sleeping
Sleeping
Joel Woodfield commited on
Commit ·
85ac76b
1
Parent(s): 484caec
Add automatic settings generation
Browse files
backend/src/__pycache__/logic.cpython-312.pyc
CHANGED
|
Binary files a/backend/src/__pycache__/logic.cpython-312.pyc and b/backend/src/__pycache__/logic.cpython-312.pyc differ
|
|
|
backend/src/__pycache__/manager.cpython-312.pyc
CHANGED
|
Binary files a/backend/src/__pycache__/manager.cpython-312.pyc and b/backend/src/__pycache__/manager.cpython-312.pyc differ
|
|
|
backend/src/logic.py
CHANGED
|
@@ -260,4 +260,27 @@ def compute_plot_values(
|
|
| 260 |
reg_levels=reg_levels,
|
| 261 |
unreg_solution=unreg_solution,
|
| 262 |
path=path,
|
| 263 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
reg_levels=reg_levels,
|
| 261 |
unreg_solution=unreg_solution,
|
| 262 |
path=path,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def compute_suggested_settings(
|
| 267 |
+
dataset: Dataset
|
| 268 |
+
) -> tuple[tuple[float, float], tuple[float, float], list[float]]:
|
| 269 |
+
x = np.stack([dataset.x1, dataset.x2], axis=1)
|
| 270 |
+
moore_penrose = np.linalg.pinv(x) @ np.array(dataset.y)
|
| 271 |
+
|
| 272 |
+
if np.isclose(moore_penrose, 0).all():
|
| 273 |
+
w1_range = (-10, 10)
|
| 274 |
+
w2_range = (-10, 10)
|
| 275 |
+
return w1_range, w2_range, []
|
| 276 |
+
|
| 277 |
+
width = np.max(np.abs(moore_penrose)) * 3
|
| 278 |
+
|
| 279 |
+
w1_range = (-width, width)
|
| 280 |
+
w2_range = (-width, width)
|
| 281 |
+
|
| 282 |
+
opt_norm = float(np.linalg.norm(moore_penrose, ord=2))
|
| 283 |
+
|
| 284 |
+
reg_levels = [i / 4 * opt_norm for i in range(3)]
|
| 285 |
+
|
| 286 |
+
return w1_range, w2_range, reg_levels
|
backend/src/manager.py
CHANGED
|
@@ -11,6 +11,7 @@ from logic import (
|
|
| 11 |
Dataset,
|
| 12 |
PlotsData,
|
| 13 |
compute_plot_values,
|
|
|
|
| 14 |
generate_dataset,
|
| 15 |
load_dataset_from_csv,
|
| 16 |
)
|
|
@@ -156,6 +157,18 @@ class Manager:
|
|
| 156 |
strength_plot = self._generate_strength_plot(self.plots_data.path)
|
| 157 |
return self, contour_plot, data_plot, strength_plot
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
@staticmethod
|
| 160 |
def _generate_contour_plot(plots_data: PlotsData) -> Figure:
|
| 161 |
fig, ax = plt.subplots(figsize=(8, 8))
|
|
|
|
| 11 |
Dataset,
|
| 12 |
PlotsData,
|
| 13 |
compute_plot_values,
|
| 14 |
+
compute_suggested_settings,
|
| 15 |
generate_dataset,
|
| 16 |
load_dataset_from_csv,
|
| 17 |
)
|
|
|
|
| 157 |
strength_plot = self._generate_strength_plot(self.plots_data.path)
|
| 158 |
return self, contour_plot, data_plot, strength_plot
|
| 159 |
|
| 160 |
+
def handle_use_suggested_settings(self) -> tuple[Manager, str, str, str]:
|
| 161 |
+
if self.dataset is None:
|
| 162 |
+
raise ValueError("Dataset is not initialized")
|
| 163 |
+
|
| 164 |
+
w1_range, w2_range, reg_levels = compute_suggested_settings(self.dataset)
|
| 165 |
+
|
| 166 |
+
w1_range_input = f"{w1_range[0]:.2f}, {w1_range[1]:.2f}"
|
| 167 |
+
w2_range_input = f"{w2_range[0]:.2f}, {w2_range[1]:.2f}"
|
| 168 |
+
reg_levels_input = ", ".join(f"{level:.2f}" for level in reg_levels)
|
| 169 |
+
|
| 170 |
+
return self, w1_range_input, w2_range_input, reg_levels_input
|
| 171 |
+
|
| 172 |
@staticmethod
|
| 173 |
def _generate_contour_plot(plots_data: PlotsData) -> Figure:
|
| 174 |
fig, ax = plt.subplots(figsize=(8, 8))
|
frontends/gradio/main.py
CHANGED
|
@@ -92,6 +92,11 @@ def handle_generate_plots(
|
|
| 92 |
resolution,
|
| 93 |
)
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
def launch():
|
| 97 |
default_dataset_type = "Generate"
|
|
@@ -256,14 +261,14 @@ def launch():
|
|
| 256 |
value=default_loss_type,
|
| 257 |
interactive=True,
|
| 258 |
)
|
| 259 |
-
|
| 260 |
-
with gr.Row():
|
| 261 |
regularizer_type_dropdown = gr.Dropdown(
|
| 262 |
label="Regularizer type",
|
| 263 |
choices=["l1", "l2"],
|
| 264 |
value=default_regularizer_type,
|
| 265 |
interactive=True,
|
| 266 |
)
|
|
|
|
|
|
|
| 267 |
regularizer_levels_textbox = gr.Textbox(
|
| 268 |
label="Regularization levels (comma-separated)",
|
| 269 |
value=default_reg_levels,
|
|
@@ -290,6 +295,18 @@ def launch():
|
|
| 290 |
step=1,
|
| 291 |
interactive=True,
|
| 292 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
gr.Button("Regenerate Plots").click(
|
| 295 |
fn=handle_generate_plots,
|
|
|
|
| 92 |
resolution,
|
| 93 |
)
|
| 94 |
|
| 95 |
+
def handle_use_suggested_settings(
|
| 96 |
+
manager: Manager,
|
| 97 |
+
) -> tuple[Manager, str, str, str]:
|
| 98 |
+
return manager.handle_use_suggested_settings()
|
| 99 |
+
|
| 100 |
|
| 101 |
def launch():
|
| 102 |
default_dataset_type = "Generate"
|
|
|
|
| 261 |
value=default_loss_type,
|
| 262 |
interactive=True,
|
| 263 |
)
|
|
|
|
|
|
|
| 264 |
regularizer_type_dropdown = gr.Dropdown(
|
| 265 |
label="Regularizer type",
|
| 266 |
choices=["l1", "l2"],
|
| 267 |
value=default_regularizer_type,
|
| 268 |
interactive=True,
|
| 269 |
)
|
| 270 |
+
|
| 271 |
+
with gr.Row():
|
| 272 |
regularizer_levels_textbox = gr.Textbox(
|
| 273 |
label="Regularization levels (comma-separated)",
|
| 274 |
value=default_reg_levels,
|
|
|
|
| 295 |
step=1,
|
| 296 |
interactive=True,
|
| 297 |
)
|
| 298 |
+
|
| 299 |
+
gr.Button("Use suggested settings").click(
|
| 300 |
+
fn=handle_use_suggested_settings,
|
| 301 |
+
inputs=[manager_state],
|
| 302 |
+
outputs=[
|
| 303 |
+
manager_state,
|
| 304 |
+
w1_range_textbox,
|
| 305 |
+
w2_range_textbox,
|
| 306 |
+
regularizer_levels_textbox,
|
| 307 |
+
],
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
|
| 311 |
gr.Button("Regenerate Plots").click(
|
| 312 |
fn=handle_generate_plots,
|