temp2 / luanch_app.py
noamkay's picture
Upload folder using huggingface_hub
5621fe8
from vizualize_nn import *
# ## Launch the app
# Builds the Gradio UI for training a small network and visualizing it, then
# serves it. `gr`, `torch`, `partial`, `init_net_and_train`, `generate_images`,
# `get_network_with_inputs`, `NETWORK_ORIENTAION` and `dummy_context` are all
# expected to come from the `vizualize_nn` star import above.

# Train on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bind the chosen device so the training callback actually receives it.
# NOTE(review): if this selects 'cuda', generate_images / get_network_with_inputs
# must cope with a CUDA-resident model — confirm in vizualize_nn.
init_net_and_train_part = partial(init_net_and_train, device=device)

APP_TITLE = "Neural Network Visualization"
APP_DESCRIPTION = "Adjust parameters and train the network to see its performance and visualization."


with gr.Blocks(title=APP_TITLE) as iface:
    # gr.Blocks has no `description` setting, so render the description as page text.
    gr.Markdown(APP_DESCRIPTION)

    tab_train = gr.Tab("Network Training")
    tab_viz = gr.Tab("Network Visualization")

    with tab_train:
        hidden_units_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="number of neurons in hidden layer")
        noise_slider = gr.Slider(minimum=0.001, maximum=0.7, step=0.01, value=0.2, label="Noise")
        epochs_slider = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Epochs")
        lr_slider = gr.Slider(minimum=0.001, maximum=0.05, step=0.001, value=0.008, label="Learning Rate")
        data_points_slider = gr.Slider(minimum=100, maximum=2000, step=4, value=1000, label="Data Points")
        train_button = gr.Button("Train Network")
        learning_curve = gr.Plot(label="Learning Curve")

    with tab_viz:
        # NETWORK_ORIENTAION (presumably 'h' / 'v'; defined in vizualize_nn) picks
        # the layout: containers replaced by dummy_context() are skipped, so the
        # same widgets are laid out side by side or stacked.
        with (gr.Row() if NETWORK_ORIENTAION != 'h' else dummy_context()):
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with (gr.Row() if NETWORK_ORIENTAION != 'v' else dummy_context()):
                    # Maximum is narrowed to the trained epoch count after each training run.
                    epoch_viz_slider = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="Visualize Epoch")
                    ner_bounds = gr.Checkbox(label="Individual neurons decision boundaries")
                    generate_button = gr.Button("Visualize Network")
                plot_output = gr.Plot(label="Decision Boundary")
                overall_net_output = gr.Image(type="filepath", label="Network Visualization")
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with gr.Row():
                    input_x = gr.Number(label="Input X")
                    input_y = gr.Number(label="Input Y")
                update_button = gr.Button("Check Input")
                net_activity_sample_output = gr.HTML(label="Network Activity for an Input")

    def _clamp_viz_epoch(n_epochs, current_epoch):
        """Limit the visualization slider to the epochs that were just trained."""
        n_epochs = int(n_epochs)
        return gr.update(maximum=n_epochs, value=min(int(current_epoch), n_epochs))

    # Set up button click actions (Gradio requires listeners inside the Blocks context).
    train_button.click(
        fn=init_net_and_train_part,
        inputs=[hidden_units_slider, noise_slider, epochs_slider, data_points_slider, lr_slider],
        outputs=learning_curve,
    ).then(
        # Keep "Visualize Epoch" within the range that actually exists after training.
        fn=_clamp_viz_epoch,
        inputs=[epochs_slider, epoch_viz_slider],
        outputs=epoch_viz_slider,
    )
    generate_button.click(fn=generate_images, inputs=[epoch_viz_slider, ner_bounds], outputs=[plot_output, overall_net_output])
    update_button.click(fn=get_network_with_inputs, inputs=[epoch_viz_slider, input_x, input_y], outputs=net_activity_sample_output)

# Launch the app. Left unguarded on purpose: this file is not the Space's
# `app.py` and may be imported by the real entry point.
iface.launch()