from vizualize_nn import *

# ## Launch the app
#
# Builds a two-tab Gradio UI (training + visualization), wires the buttons to
# the project callbacks, and starts the server.
#
# NOTE(review): `torch`, `partial`, `gr`, `dummy_context`, `NETWORK_ORIENTAION`
# and the callbacks (`init_net_and_train`, `generate_images`,
# `get_network_with_inputs`) are not imported explicitly in this file;
# presumably they all arrive through the star import above -- confirm.

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training callback with the device pre-bound, so the UI only has to supply
# the slider values.
init_net_and_train_part = partial(init_net_and_train, device=device)


with gr.Blocks() as iface:

    tab_train = gr.Tab("Network Training")
    tab_viz = gr.Tab("Network Visualization")

    # --- Tab 1: hyper-parameter sliders, train button, learning curve ---
    with tab_train:
        hidden_units_slider = gr.Slider(
            minimum=1, maximum=10, step=1, value=4,
            label="number of neurons in hidden layer",
        )
        noise_slider = gr.Slider(
            minimum=0.001, maximum=0.7, step=0.01, value=0.2, label="Noise",
        )
        epochs_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=30, label="Epochs",
        )
        lr_slider = gr.Slider(
            minimum=0.001, maximum=0.05, step=0.001, value=0.008,
            label="Learning Rate",
        )
        data_points_slider = gr.Slider(
            minimum=100, maximum=2000, step=4, value=1000, label="Data Points",
        )
        train_button = gr.Button("Train Network")
        learning_curve = gr.Plot(label="Learning Curve")

    # --- Tab 2: per-epoch visualization and single-input inspection ---
    # The Row/Column wrappers are swapped for `dummy_context()` depending on
    # NETWORK_ORIENTAION ('h' disables the outer Row and both Columns, 'v'
    # disables the inner controls Row).
    # NOTE(review): the nesting below was reconstructed from a flattened
    # source; which widgets sit inside the inner Rows should be confirmed
    # against the intended layout.
    with tab_viz:
        with (gr.Row() if NETWORK_ORIENTAION != 'h' else dummy_context()):
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with (gr.Row() if NETWORK_ORIENTAION != 'v' else dummy_context()):
                    # Upper bound matches the maximum of `epochs_slider`.
                    epoch_viz_slider = gr.Slider(
                        minimum=1, maximum=50, step=1, value=1,
                        label="Visualize Epoch",
                    )
                    ner_bounds = gr.Checkbox(
                        label="Invidual neurons decision boundaries",
                    )
                    generate_button = gr.Button("Visualize Network")
                plot_output = gr.Plot(label="Decision Boundary")
                overall_net_output = gr.Image(
                    type="filepath", label="Network Visualization",
                )
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with gr.Row():
                    input_x = gr.Number(label="Input X")
                    input_y = gr.Number(label="Input Y")
                    update_button = gr.Button("Check Input")
                net_activity_sample_output = gr.HTML(
                    label="Network Activity for an Input",
                )

    # --- Event wiring (must be declared inside the Blocks context) ---
    # Use the device-bound partial: the original passed the bare
    # `init_net_and_train`, which left `init_net_and_train_part` unused and
    # never forwarded the detected device to training.
    train_button.click(
        fn=init_net_and_train_part,
        inputs=[
            hidden_units_slider,
            noise_slider,
            epochs_slider,
            data_points_slider,
            lr_slider,
        ],
        outputs=learning_curve,
    )
    generate_button.click(
        fn=generate_images,
        inputs=[epoch_viz_slider, ner_bounds],
        outputs=[plot_output, overall_net_output],
    )
    update_button.click(
        fn=get_network_with_inputs,
        inputs=[epoch_viz_slider, input_x, input_y],
        outputs=net_activity_sample_output,
    )


iface.title = "Neural Network Visualization"
# NOTE(review): `description` is set as a plain attribute here; whether
# gr.Blocks actually displays it anywhere should be confirmed.
iface.description = "Adjust parameters and train the network to see its performance and visualization."


iface.launch()
|
|