File size: 2,962 Bytes
5621fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from vizualize_nn import *
"""## Launch the app"""

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Variant of init_net_and_train with the `device` keyword argument pre-bound.
init_net_and_train_part = partial(init_net_and_train, device=device)

# Build the two-tab UI (training + visualization) and wire buttons to callbacks.
with gr.Blocks() as iface:

    tab_train = gr.Tab("Network Training")
    tab_viz = gr.Tab("Network Visualization")

    # Training tab: hyper-parameter sliders, the trigger button and the
    # plot that receives the training callback's output.
    with tab_train:
        hidden_units_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="number of neurons in hidden layer")
        noise_slider = gr.Slider(minimum=0.001, maximum=0.7, step=0.01, value=0.2, label="Noise")
        epochs_slider = gr.Slider(minimum=1, maximum=50, step=1, value=30, label="Epochs")
        lr_slider = gr.Slider(minimum=0.001, maximum=0.05, step=0.001, value=0.008, label="Learning Rate")
        data_points_slider = gr.Slider(minimum=100, maximum=2000, step=4, value=1000, label="Data Points")
        train_button = gr.Button("Train Network")
        learning_curve = gr.Plot(label="Learning Curve")

    # Visualization tab. The Row/Column wrappers are chosen from
    # NETWORK_ORIENTAION: for 'h' the outer Row and both Columns are swapped
    # for dummy_context(); for 'v' the controls Row is swapped instead.
    # NOTE(review): dummy_context is presumably a no-op context manager and
    # 'h'/'v' presumably mean horizontal/vertical -- confirm in vizualize_nn.
    with tab_viz:
        with (gr.Row() if NETWORK_ORIENTAION != 'h' else dummy_context()):
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with (gr.Row() if NETWORK_ORIENTAION != 'v' else dummy_context()):
                    # TODO: the upper bound is fixed at 50 (the maximum of the
                    # "Epochs" slider); it should follow the trained epoch count.
                    epoch_viz_slider = gr.Slider(minimum=1, maximum=50, step=1, value=1, label="Visualize Epoch")
                    ner_bounds = gr.Checkbox(label="Invidual neurons decision boundaries")
                    generate_button = gr.Button("Visualize Network")
                plot_output = gr.Plot(label="Decision Boundary")
                overall_net_output = gr.Image(type="filepath", label="Network Visualization")
            with (gr.Column() if NETWORK_ORIENTAION != 'h' else dummy_context()):
                with gr.Row():
                    input_x = gr.Number(label="Input X")
                    input_y = gr.Number(label="Input Y")
                    update_button = gr.Button("Check Input")
                net_activity_sample_output = gr.HTML(label="Network Activity for an Input")

    # Set up button click actions.
    # Training is routed through init_net_and_train_part so the detected
    # device is actually passed on; the bare init_net_and_train was registered
    # before, which left the device-bound partial unused.
    train_button.click(
        fn=init_net_and_train_part,
        inputs=[hidden_units_slider, noise_slider, epochs_slider, data_points_slider, lr_slider],
        outputs=learning_curve,
    )
    generate_button.click(
        fn=generate_images,
        inputs=[epoch_viz_slider, ner_bounds],
        outputs=[plot_output, overall_net_output],
    )
    update_button.click(
        fn=get_network_with_inputs,
        inputs=[epoch_viz_slider, input_x, input_y],
        outputs=net_activity_sample_output,
    )

# Attach the page title and description text to the Blocks app.
# NOTE(review): gr.Blocks may not render a `description` attribute -- confirm
# it is displayed, or show the text through a component instead.
iface.title, iface.description = (
    "Neural Network Visualization",
    "Adjust parameters and train the network to see its performance and visualization.",
)

# Start the Gradio server for the app.
iface.launch()