| | from vizualize_nn import * |
| | """## Launch the app""" |
| |
|
# gr, torch, partial, dummy_context, NETWORK_ORIENTAION and the callbacks
# init_net_and_train / generate_images / get_network_with_inputs all come from
# the ``from vizualize_nn import *`` at the top of the file.

APP_TITLE = "Neural Network Visualization"
APP_DESCRIPTION = "Adjust parameters and train the network to see its performance and visualization."
# Shared upper bound for the "Epochs" and "Visualize Epoch" sliders so the two
# ranges cannot drift apart.
MAX_EPOCHS = 50

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Training entry point with the chosen device pre-bound. This (not the bare
# init_net_and_train) is what the "Train Network" button calls, so the device
# selected above is actually used.
init_net_and_train_part = partial(init_net_and_train, device=device)


def _layout_unless(container_factory, skip_orientation):
    """Return a layout container unless NETWORK_ORIENTAION equals *skip_orientation*.

    Args:
        container_factory: Gradio layout class to instantiate, e.g. ``gr.Row``
            or ``gr.Column``.
        skip_orientation: NETWORK_ORIENTAION value ('h' or 'v') for which the
            container is replaced by ``dummy_context()`` (presumably a no-op
            context manager from vizualize_nn -- TODO confirm).

    Returns:
        A context manager to be used in a ``with`` statement.
    """
    if NETWORK_ORIENTAION == skip_orientation:
        return dummy_context()
    return container_factory()


with gr.Blocks(title=APP_TITLE) as iface:
    # gr.Blocks has no description field, so show the text on the page itself.
    gr.Markdown(APP_DESCRIPTION)

    with gr.Tab("Network Training") as tab_train:
        hidden_units_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="number of neurons in hidden layer")
        noise_slider = gr.Slider(minimum=0.001, maximum=0.7, step=0.01, value=0.2, label="Noise")
        epochs_slider = gr.Slider(minimum=1, maximum=MAX_EPOCHS, step=1, value=30, label="Epochs")
        lr_slider = gr.Slider(minimum=0.001, maximum=0.05, step=0.001, value=0.008, label="Learning Rate")
        data_points_slider = gr.Slider(minimum=100, maximum=2000, step=4, value=1000, label="Data Points")
        train_button = gr.Button("Train Network")
        learning_curve = gr.Plot(label="Learning Curve")

    with gr.Tab("Network Visualization") as tab_viz:
        # When NETWORK_ORIENTAION is 'h', the outer Row and both Columns are
        # skipped; when it is 'v', only the Row around the controls is skipped.
        with _layout_unless(gr.Row, 'h'):
            with _layout_unless(gr.Column, 'h'):
                with _layout_unless(gr.Row, 'v'):
                    epoch_viz_slider = gr.Slider(minimum=1, maximum=MAX_EPOCHS, step=1, value=1, label="Visualize Epoch")
                    ner_bounds = gr.Checkbox(label="Individual neurons decision boundaries")
                    generate_button = gr.Button("Visualize Network")
                plot_output = gr.Plot(label="Decision Boundary")
                overall_net_output = gr.Image(type="filepath", label="Network Visualization")
            with _layout_unless(gr.Column, 'h'):
                with gr.Row():
                    input_x = gr.Number(label="Input X")
                    input_y = gr.Number(label="Input Y")
                    update_button = gr.Button("Check Input")
                net_activity_sample_output = gr.HTML(label="Network Activity for an Input")

    # Event wiring. Input order must match each callback's positional
    # parameters (note: data points come before learning rate for training).
    train_button.click(
        fn=init_net_and_train_part,
        inputs=[hidden_units_slider, noise_slider, epochs_slider, data_points_slider, lr_slider],
        outputs=learning_curve,
    )
    generate_button.click(
        fn=generate_images,
        inputs=[epoch_viz_slider, ner_bounds],
        outputs=[plot_output, overall_net_output],
    )
    update_button.click(
        fn=get_network_with_inputs,
        inputs=[epoch_viz_slider, input_x, input_y],
        outputs=net_activity_sample_output,
    )

# The title is applied through gr.Blocks(title=...) above; these attributes are
# kept so any code reading them still sees the same values.
iface.title = APP_TITLE
iface.description = APP_DESCRIPTION

iface.launch()
| |
|