Spaces:
Build error
Build error
Commit
·
9117cbd
1
Parent(s):
de85694
getting rid of regression tab, thinking about having a custom data tab
Browse fileswhere users can upload their own dataset and hyper parameter tune on
that. and making the classification tab the Classification example tab.
app.py
CHANGED
|
@@ -34,21 +34,21 @@ X_train, X_test, y_train, y_test = _preprocess_digits(seed=1)
|
|
| 34 |
|
| 35 |
def classification(
|
| 36 |
seed: int,
|
| 37 |
-
|
| 38 |
-
|
| 39 |
loss_fn_str: str,
|
| 40 |
epochs: int,
|
| 41 |
hidden_size: int,
|
| 42 |
batch_size: float,
|
| 43 |
learning_rate: float,
|
| 44 |
) -> tuple[gr.Plot, gr.Plot, gr.Label]:
|
| 45 |
-
assert
|
| 46 |
-
assert
|
| 47 |
assert loss_fn_str in nn.LOSSES
|
| 48 |
|
| 49 |
loss_fn: nn.Loss = nn.LOSSES[loss_fn_str]
|
| 50 |
-
h_act_fn: nn.Activation = nn.ACTIVATIONS[
|
| 51 |
-
o_act_fn: nn.Activation = nn.ACTIVATIONS[
|
| 52 |
|
| 53 |
nn_classifier = nn.NN(
|
| 54 |
epochs=epochs,
|
|
@@ -164,7 +164,4 @@ if __name__ == "__main__":
|
|
| 164 |
outputs=plt_outputs + label_output,
|
| 165 |
)
|
| 166 |
|
| 167 |
-
with gr.Tab("Regression"):
|
| 168 |
-
gr.Markdown("### Coming Soon")
|
| 169 |
-
|
| 170 |
interface.launch(show_error=True)
|
|
|
|
| 34 |
|
| 35 |
def classification(
|
| 36 |
seed: int,
|
| 37 |
+
hidden_layer_activation_fn_str: str,
|
| 38 |
+
output_layer_activation_fn_str: str,
|
| 39 |
loss_fn_str: str,
|
| 40 |
epochs: int,
|
| 41 |
hidden_size: int,
|
| 42 |
batch_size: float,
|
| 43 |
learning_rate: float,
|
| 44 |
) -> tuple[gr.Plot, gr.Plot, gr.Label]:
|
| 45 |
+
assert hidden_layer_activation_fn_str in nn.ACTIVATIONS
|
| 46 |
+
assert output_layer_activation_fn_str in nn.ACTIVATIONS
|
| 47 |
assert loss_fn_str in nn.LOSSES
|
| 48 |
|
| 49 |
loss_fn: nn.Loss = nn.LOSSES[loss_fn_str]
|
| 50 |
+
h_act_fn: nn.Activation = nn.ACTIVATIONS[hidden_layer_activation_fn_str]
|
| 51 |
+
o_act_fn: nn.Activation = nn.ACTIVATIONS[output_layer_activation_fn_str]
|
| 52 |
|
| 53 |
nn_classifier = nn.NN(
|
| 54 |
epochs=epochs,
|
|
|
|
| 164 |
outputs=plt_outputs + label_output,
|
| 165 |
)
|
| 166 |
|
|
|
|
|
|
|
|
|
|
| 167 |
interface.launch(show_error=True)
|