joel-woodfield commited on
Commit
cff92ea
·
2 Parent(s): ace09ad3e26aca

Merge branch 'main' of hf.co:spaces/elvis-hf/mlp_visualizer

Browse files
README.md CHANGED
@@ -5,8 +5,8 @@ colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
  app_file: frontends/gradio/main.py
8
- sdk_version: 6.3.0
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
5
  colorTo: gray
6
  sdk: gradio
7
  app_file: frontends/gradio/main.py
8
+ sdk_version: 6.5.1
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
backend/src/__pycache__/logic.cpython-310.pyc ADDED
Binary file (4.91 kB). View file
 
backend/src/__pycache__/logic.cpython-314.pyc ADDED
Binary file (11 kB). View file
 
backend/src/__pycache__/manager.cpython-310.pyc ADDED
Binary file (5.05 kB). View file
 
backend/src/__pycache__/manager.cpython-314.pyc ADDED
Binary file (10.6 kB). View file
 
backend/src/manager.py CHANGED
@@ -37,6 +37,7 @@ class OptimizerOptions:
37
  class Manager:
38
  def __init__(self) -> None:
39
  self._dataset: Dataset | None = None
 
40
 
41
  self._architecture: str | None = None
42
  self._model: nn.Module | None = None
@@ -188,7 +189,7 @@ class Manager:
188
  else:
189
  test_dataset = self._true_dataset
190
 
191
- if test_dataset is not None and self._model is not None:
192
  test_predictions = generate_test_predictions(
193
  test_dataset,
194
  self._model,
 
37
  class Manager:
38
  def __init__(self) -> None:
39
  self._dataset: Dataset | None = None
40
+ self._true_dataset: Dataset | None = None
41
 
42
  self._architecture: str | None = None
43
  self._model: nn.Module | None = None
 
189
  else:
190
  test_dataset = self._true_dataset
191
 
192
+ if test_dataset.x and test_dataset.y and self._model is not None:
193
  test_predictions = generate_test_predictions(
194
  test_dataset,
195
  self._model,
frontends/gradio/__pycache__/main.cpython-314.pyc ADDED
Binary file (17.2 kB). View file
 
frontends/gradio/main.py CHANGED
@@ -48,7 +48,7 @@ def handle_set_dataset(
48
  has_header: bool,
49
  xcol: int,
50
  ycol: int,
51
- ) -> PlotData:
52
  options = {
53
  "dataset_type": dataset_type,
54
  "function": function,
@@ -62,17 +62,16 @@ def handle_set_dataset(
62
  "xcol": xcol,
63
  "ycol": ycol,
64
  }
65
- plot_data = manager.handle_set_dataset(options)
66
-
67
- return plot_data
68
 
69
 
70
  def handle_set_architecture(
71
  manager: Manager,
72
  architecture_str: str,
73
- ) -> PlotData:
74
- plot_data = manager.handle_set_architecture(architecture_str)
75
- return plot_data
76
 
77
 
78
  def handle_set_optimizer(
@@ -84,7 +83,7 @@ def handle_set_optimizer(
84
  momentum: float,
85
  weight_decay: float,
86
  batch_size: int,
87
- ) -> PlotData:
88
  options = {
89
  "optimizer_type": optimizer,
90
  "learning_rate": learning_rate,
@@ -94,23 +93,24 @@ def handle_set_optimizer(
94
  "weight_decay": weight_decay,
95
  "batch_size": batch_size,
96
  }
97
- plot_data = manager.handle_set_optimizer(options)
98
- return plot_data
 
99
 
100
 
101
  def handle_train_step(
102
  manager: Manager,
103
  step_increment: int,
104
- ) -> PlotData:
105
- plot_data = manager.handle_train_step(step_increment)
106
- return plot_data
107
 
108
 
109
  def handle_reset_model(
110
  manager: Manager,
111
- ) -> PlotData:
112
- plot_data = manager.handle_reset_model()
113
- return plot_data
114
 
115
 
116
  def generate_plot(data: PlotData) -> Figure:
@@ -224,43 +224,40 @@ def launch():
224
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP visualizer</div>")
225
 
226
  manager = gr.State(Manager())
227
- handle_set_dataset(
228
- manager.value,
229
- default_dataset_type,
230
- default_function,
231
- default_data_xmin,
232
- default_data_xmax,
233
- default_sigma,
234
- default_nsample,
235
- default_sample_method,
236
- default_csv_file,
237
- default_has_header,
238
- default_xcol,
239
- default_ycol,
 
240
  )
241
- handle_set_architecture(
242
- manager.value,
243
  ARCHITECTURE_PRESETS[default_architecture_preset],
244
  )
245
- handle_set_optimizer(
246
- manager.value,
247
- default_optimizer,
248
- default_learning_rate,
249
- default_beta1,
250
- default_beta2,
251
- default_momentum,
252
- default_weight_decay,
253
- default_batch_size,
254
- )
255
-
256
- plot_data = gr.State(
257
- manager.value.get_plot_data()
258
  )
259
 
260
  with gr.Row():
261
  with gr.Column(scale=2):
262
  plot = gr.Plot(
263
- value=generate_plot(plot_data.value),
264
  )
265
 
266
  with gr.Tab("Data"):
@@ -375,10 +372,6 @@ def launch():
375
  xcol,
376
  ycol,
377
  ],
378
- outputs=[plot_data],
379
- ).then(
380
- fn=generate_plot,
381
- inputs=[plot_data],
382
  outputs=[plot],
383
  )
384
 
@@ -418,10 +411,6 @@ def launch():
418
  manager,
419
  architecture,
420
  ],
421
- outputs=[plot_data],
422
- ).then(
423
- fn=generate_plot,
424
- inputs=[plot_data],
425
  outputs=[plot],
426
  )
427
 
@@ -486,10 +475,6 @@ def launch():
486
 
487
  update_optimizer_button = gr.Button("Update Optimizer")
488
  update_optimizer_button.click(
489
- fn=manager.value.handle_reset_model,
490
- inputs=[],
491
- outputs=[plot_data],
492
- ).then(
493
  fn=handle_set_optimizer,
494
  inputs=[
495
  manager,
@@ -501,10 +486,6 @@ def launch():
501
  weight_decay,
502
  batch_size,
503
  ],
504
- outputs=[plot_data],
505
- ).then(
506
- fn=generate_plot,
507
- inputs=[plot_data],
508
  outputs=[plot],
509
  )
510
 
@@ -518,10 +499,6 @@ def launch():
518
  train_button.click(
519
  fn=handle_train_step,
520
  inputs=[manager, step_increment],
521
- outputs=[plot_data],
522
- ).then(
523
- fn=generate_plot,
524
- inputs=[plot_data],
525
  outputs=[plot],
526
  )
527
 
@@ -529,10 +506,6 @@ def launch():
529
  reset_button.click(
530
  fn=handle_reset_model,
531
  inputs=[manager],
532
- outputs=[plot_data],
533
- ).then(
534
- fn=generate_plot,
535
- inputs=[plot_data],
536
  outputs=[plot],
537
  )
538
 
 
48
  has_header: bool,
49
  xcol: int,
50
  ycol: int,
51
+ ) -> Figure:
52
  options = {
53
  "dataset_type": dataset_type,
54
  "function": function,
 
62
  "xcol": xcol,
63
  "ycol": ycol,
64
  }
65
+ manager.handle_set_dataset(options)
66
+ return generate_plot(manager.get_plot_data())
 
67
 
68
 
69
  def handle_set_architecture(
70
  manager: Manager,
71
  architecture_str: str,
72
+ ) -> Figure:
73
+ manager.handle_set_architecture(architecture_str)
74
+ return generate_plot(manager.get_plot_data())
75
 
76
 
77
  def handle_set_optimizer(
 
83
  momentum: float,
84
  weight_decay: float,
85
  batch_size: int,
86
+ ) -> Figure:
87
  options = {
88
  "optimizer_type": optimizer,
89
  "learning_rate": learning_rate,
 
93
  "weight_decay": weight_decay,
94
  "batch_size": batch_size,
95
  }
96
+ manager.handle_reset_model()
97
+ manager.handle_set_optimizer(options)
98
+ return generate_plot(manager.get_plot_data())
99
 
100
 
101
  def handle_train_step(
102
  manager: Manager,
103
  step_increment: int,
104
+ ) -> Figure:
105
+ manager.handle_train_step(step_increment)
106
+ return generate_plot(manager.get_plot_data())
107
 
108
 
109
  def handle_reset_model(
110
  manager: Manager,
111
+ ) -> Figure:
112
+ manager.handle_reset_model()
113
+ return generate_plot(manager.get_plot_data())
114
 
115
 
116
  def generate_plot(data: PlotData) -> Figure:
 
224
  gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP visualizer</div>")
225
 
226
  manager = gr.State(Manager())
227
+ manager.value.handle_set_dataset(
228
+ {
229
+ "dataset_type": default_dataset_type,
230
+ "function": default_function,
231
+ "xmin": default_data_xmin,
232
+ "xmax": default_data_xmax,
233
+ "sigma": default_sigma,
234
+ "nsample": default_nsample,
235
+ "sample_method": default_sample_method,
236
+ "csv_path": default_csv_file,
237
+ "has_header": default_has_header,
238
+ "xcol": default_xcol,
239
+ "ycol": default_ycol,
240
+ }
241
  )
242
+ manager.value.handle_set_architecture(
 
243
  ARCHITECTURE_PRESETS[default_architecture_preset],
244
  )
245
+ manager.value.handle_set_optimizer(
246
+ {
247
+ "optimizer_type": default_optimizer,
248
+ "learning_rate": default_learning_rate,
249
+ "beta1": default_beta1,
250
+ "beta2": default_beta2,
251
+ "momentum": default_momentum,
252
+ "weight_decay": default_weight_decay,
253
+ "batch_size": default_batch_size,
254
+ }
 
 
 
255
  )
256
 
257
  with gr.Row():
258
  with gr.Column(scale=2):
259
  plot = gr.Plot(
260
+ value=generate_plot(manager.value.get_plot_data()),
261
  )
262
 
263
  with gr.Tab("Data"):
 
372
  xcol,
373
  ycol,
374
  ],
 
 
 
 
375
  outputs=[plot],
376
  )
377
 
 
411
  manager,
412
  architecture,
413
  ],
 
 
 
 
414
  outputs=[plot],
415
  )
416
 
 
475
 
476
  update_optimizer_button = gr.Button("Update Optimizer")
477
  update_optimizer_button.click(
 
 
 
 
478
  fn=handle_set_optimizer,
479
  inputs=[
480
  manager,
 
486
  weight_decay,
487
  batch_size,
488
  ],
 
 
 
 
489
  outputs=[plot],
490
  )
491
 
 
499
  train_button.click(
500
  fn=handle_train_step,
501
  inputs=[manager, step_increment],
 
 
 
 
502
  outputs=[plot],
503
  )
504
 
 
506
  reset_button.click(
507
  fn=handle_reset_model,
508
  inputs=[manager],
 
 
 
 
509
  outputs=[plot],
510
  )
511