Spaces:
Sleeping
Sleeping
Commit ·
5c45b99
1
Parent(s): 16aafd3
Rename DatasetOptions to Dataset for clarity
Browse files- dataset_options.py → dataset.py +7 -7
- mlp_visualizer.py +4 -4
dataset_options.py → dataset.py
RENAMED
|
@@ -38,7 +38,7 @@ def get_data_points(function, xlim=(-1, 1), nsample=10, sigma=0, seed=0):
|
|
| 38 |
return x, y
|
| 39 |
|
| 40 |
|
| 41 |
-
class
|
| 42 |
def __init__(
|
| 43 |
self,
|
| 44 |
mode: str = "generate",
|
|
@@ -89,7 +89,7 @@ class DatasetOptions:
|
|
| 89 |
raise ValueError(f"Unknown dataset mode: {self.mode}")
|
| 90 |
|
| 91 |
def update(self, **kwargs):
|
| 92 |
-
return
|
| 93 |
mode=kwargs.get("mode", self.mode),
|
| 94 |
function=kwargs.get("function", self.function),
|
| 95 |
xmin=kwargs.get("xmin", self.xmin),
|
|
@@ -121,7 +121,7 @@ class DatasetOptions:
|
|
| 121 |
)
|
| 122 |
|
| 123 |
|
| 124 |
-
class
|
| 125 |
def update_mode(self, mode: str, state: gr.State):
|
| 126 |
state = state.update(mode=mode)
|
| 127 |
|
|
@@ -183,11 +183,11 @@ class DatasetOptionsView:
|
|
| 183 |
)
|
| 184 |
with gr.Row():
|
| 185 |
xmin = gr.Number(
|
| 186 |
-
label="
|
| 187 |
value=options.xmin,
|
| 188 |
)
|
| 189 |
xmax = gr.Number(
|
| 190 |
-
label="
|
| 191 |
value=options.xmax,
|
| 192 |
)
|
| 193 |
sigma = gr.Number(
|
|
@@ -215,7 +215,7 @@ class DatasetOptionsView:
|
|
| 215 |
outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
|
| 216 |
)
|
| 217 |
|
| 218 |
-
#
|
| 219 |
function.submit(
|
| 220 |
lambda f, s: s.update(function=f),
|
| 221 |
inputs=[function, state],
|
|
@@ -247,7 +247,7 @@ class DatasetOptionsView:
|
|
| 247 |
outputs=[state],
|
| 248 |
)
|
| 249 |
|
| 250 |
-
# csv
|
| 251 |
csv_upload.upload(
|
| 252 |
self.upload_csv,
|
| 253 |
inputs=[csv_upload, state],
|
|
|
|
| 38 |
return x, y
|
| 39 |
|
| 40 |
|
| 41 |
+
class Dataset:
|
| 42 |
def __init__(
|
| 43 |
self,
|
| 44 |
mode: str = "generate",
|
|
|
|
| 89 |
raise ValueError(f"Unknown dataset mode: {self.mode}")
|
| 90 |
|
| 91 |
def update(self, **kwargs):
|
| 92 |
+
return Dataset(
|
| 93 |
mode=kwargs.get("mode", self.mode),
|
| 94 |
function=kwargs.get("function", self.function),
|
| 95 |
xmin=kwargs.get("xmin", self.xmin),
|
|
|
|
| 121 |
)
|
| 122 |
|
| 123 |
|
| 124 |
+
class DatasetView:
|
| 125 |
def update_mode(self, mode: str, state: gr.State):
|
| 126 |
state = state.update(mode=mode)
|
| 127 |
|
|
|
|
| 183 |
)
|
| 184 |
with gr.Row():
|
| 185 |
xmin = gr.Number(
|
| 186 |
+
label="x min",
|
| 187 |
value=options.xmin,
|
| 188 |
)
|
| 189 |
xmax = gr.Number(
|
| 190 |
+
label="x max",
|
| 191 |
value=options.xmax,
|
| 192 |
)
|
| 193 |
sigma = gr.Number(
|
|
|
|
| 215 |
outputs=[state, function, xmin, xmax, sigma, nsample, regenerate, csv_upload],
|
| 216 |
)
|
| 217 |
|
| 218 |
+
# generate mode
|
| 219 |
function.submit(
|
| 220 |
lambda f, s: s.update(function=f),
|
| 221 |
inputs=[function, state],
|
|
|
|
| 247 |
outputs=[state],
|
| 248 |
)
|
| 249 |
|
| 250 |
+
# csv mode
|
| 251 |
csv_upload.upload(
|
| 252 |
self.upload_csv,
|
| 253 |
inputs=[csv_upload, state],
|
mlp_visualizer.py
CHANGED
|
@@ -29,7 +29,7 @@ logging.basicConfig(
|
|
| 29 |
)
|
| 30 |
logger = logging.getLogger("ELVIS")
|
| 31 |
|
| 32 |
-
from
|
| 33 |
|
| 34 |
|
| 35 |
class MlpVisualizer:
|
|
@@ -44,7 +44,7 @@ class MlpVisualizer:
|
|
| 44 |
display: none;
|
| 45 |
}"""
|
| 46 |
|
| 47 |
-
def plot(self, dataset_options:
|
| 48 |
print("Plotting")
|
| 49 |
t1 = time.time()
|
| 50 |
fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
|
|
@@ -98,7 +98,7 @@ class MlpVisualizer:
|
|
| 98 |
gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
|
| 99 |
|
| 100 |
# states
|
| 101 |
-
dataset_options = gr.State(
|
| 102 |
|
| 103 |
# GUI elements and layout
|
| 104 |
with gr.Row():
|
|
@@ -111,7 +111,7 @@ class MlpVisualizer:
|
|
| 111 |
|
| 112 |
with gr.Column(scale=1):
|
| 113 |
with gr.Tab("Dataset"):
|
| 114 |
-
dataset_view =
|
| 115 |
dataset_view.build(state=dataset_options)
|
| 116 |
dataset_options.change(
|
| 117 |
fn=self.plot,
|
|
|
|
| 29 |
)
|
| 30 |
logger = logging.getLogger("ELVIS")
|
| 31 |
|
| 32 |
+
from dataset import Dataset, DatasetView, get_function
|
| 33 |
|
| 34 |
|
| 35 |
class MlpVisualizer:
|
|
|
|
| 44 |
display: none;
|
| 45 |
}"""
|
| 46 |
|
| 47 |
+
def plot(self, dataset_options: Dataset):
|
| 48 |
print("Plotting")
|
| 49 |
t1 = time.time()
|
| 50 |
fig = plt.figure(figsize=(self.canvas_width / 100., self.canvas_height / 100.0), dpi=100)
|
|
|
|
| 98 |
gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>MLP Training Visualizer</div>")
|
| 99 |
|
| 100 |
# states
|
| 101 |
+
dataset_options = gr.State(Dataset())
|
| 102 |
|
| 103 |
# GUI elements and layout
|
| 104 |
with gr.Row():
|
|
|
|
| 111 |
|
| 112 |
with gr.Column(scale=1):
|
| 113 |
with gr.Tab("Dataset"):
|
| 114 |
+
dataset_view = DatasetView()
|
| 115 |
dataset_view.build(state=dataset_options)
|
| 116 |
dataset_options.change(
|
| 117 |
fn=self.plot,
|