trohith89 commited on
Commit
5a91dc0
·
verified ·
1 Parent(s): 7864c7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -98
app.py CHANGED
@@ -4,24 +4,26 @@ import matplotlib.pyplot as plt
4
  import seaborn as sns
5
  import graphviz
6
  import time
7
- from sklearn.datasets import make_moons, make_circles, make_classification
8
 
9
  # Set Streamlit page title
10
  st.set_page_config(page_title="Neural Network Trainer", layout="wide")
11
 
12
- # ================= Session State for Training Controls =================
13
  if "epoch" not in st.session_state:
14
  st.session_state.epoch = 0
 
15
  if "running" not in st.session_state:
16
  st.session_state.running = False
17
 
18
- # ================= TRAINING CONTROL PANEL (Top) =================
19
  st.markdown("### Training Controls")
20
- col1, col2, col3, col4, col5, col6, col7, col8, col9 = st.columns(9)
21
 
22
  with col1:
23
  if st.button("↩️ Reset"):
24
  st.session_state.epoch = 0
 
25
  st.session_state.running = False
26
  with col2:
27
  if st.button("▶️ Train"):
@@ -30,86 +32,31 @@ with col3:
30
  if st.button("⏸️ Pause"):
31
  st.session_state.running = False
32
  with col4:
33
- activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh", "LeakyReLU"])
34
  with col5:
35
- regularization = st.selectbox("Regularization", ["None", "L1", "L2"])
36
- with col6:
37
- reg_rate = st.selectbox("Regularization Rate", [0.0001, 0.001, 0.01, 0.1]) if regularization in ["L1", "L2"] else 0
38
- with col7:
39
- problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
40
- with col8:
41
  learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1])
42
- with col9:
43
- st.write(f"Epoch: **{st.session_state.epoch}**")
44
 
45
- # 🚀 **Fix:** Run training loop without breaking Streamlit
46
  if st.session_state.running:
47
  time.sleep(1) # Simulating training
48
  st.session_state.epoch += 1
 
49
 
50
- # ================= MAIN LAYOUT =================
51
- col_features, col_hidden, col_output = st.columns([2, 2, 2])
52
-
53
- # ========== FEATURES PANEL (Left) ========== #
54
- with col_features:
55
- st.header("FEATURES")
56
- st.write("Which properties do you want to feed in?")
57
-
58
- x1 = st.checkbox("X₁", value=True)
59
- x2 = st.checkbox("X₂", value=True)
60
- x1_squared = st.checkbox("X₁²")
61
- x2_squared = st.checkbox("X₂²")
62
- x1_x2 = st.checkbox("X₁X₂")
63
- sin_x1 = st.checkbox("sin(X₁)")
64
- sin_x2 = st.checkbox("sin(X₂)")
65
-
66
- # ========== HIDDEN LAYERS PANEL (Middle) ========== #
67
- with col_hidden:
68
- st.header("HIDDEN LAYERS")
69
- hidden_layers = st.slider("Number of Hidden Layers", 1, 7, 2)
70
-
71
- neurons = []
72
- for i in range(hidden_layers):
73
- neurons.append(st.slider(f"Neurons in Layer {i+1}", 1, 20, 4))
74
-
75
- # ========== OUTPUT PANEL (Right) ========== #
76
- with col_output:
77
- st.header("OUTPUT")
78
- st.write("Test Loss: *0.501*")
79
- st.write("Training Loss: *0.507*")
80
-
81
- # Spiral Plot with Updated Color Palette
82
- x = np.linspace(-6, 6, 300)
83
- y = np.sin(x) + np.random.normal(0, 0.1, x.shape)
84
-
85
- fig, ax = plt.subplots()
86
- sns.scatterplot(x=x, y=y, hue=x, palette="plasma", ax=ax)
87
- st.pyplot(fig)
88
-
89
- show_test_data = st.checkbox("Show test data")
90
- discretize_output = st.checkbox("Discretize output")
91
-
92
- # ================= DATASET SELECTION (Sidebar) =================
93
  st.sidebar.header("DATA")
 
 
 
94
 
95
- dataset_option = st.sidebar.radio(
96
- "Which dataset do you want to use?",
97
- ("Moons", "Circles", "Spiral"),
98
- index=2
99
- )
100
-
101
- train_ratio = st.sidebar.slider("Ratio of training to test data:", 10, 90, 50, format="%d%%")
102
- noise = st.sidebar.slider("Noise:", 0.0, 1.0, 0.1, step=0.1)
103
- batch_size = st.sidebar.slider("Batch size:", 1, 100, 10)
104
-
105
- # ================= DATASET GENERATION =================
106
  def generate_data():
107
  if dataset_option == "Moons":
108
- X, y = make_moons(n_samples=500, noise=noise)
109
  elif dataset_option == "Circles":
110
- X, y = make_circles(n_samples=500, noise=noise)
 
 
111
  else:
112
- theta = np.sqrt(np.random.rand(500)) * 2 * np.pi
113
  r = theta
114
  X = np.array([r * np.cos(theta), r * np.sin(theta)]).T
115
  y = (theta % (2 * np.pi)) > np.pi
@@ -117,44 +64,66 @@ def generate_data():
117
 
118
  X, y = generate_data()
119
 
120
- # =================== DRAW NEURAL NETWORK ===================
 
 
 
 
 
 
 
 
 
 
121
  def draw_neural_network():
122
  graph = graphviz.Digraph()
 
 
 
 
 
 
 
 
 
123
 
124
- # Input Layer
125
- graph.node("X1", "X₁", shape="circle", style="filled", fillcolor="lightblue")
126
- graph.node("X2", "X₂", shape="circle", style="filled", fillcolor="lightblue")
127
 
128
- prev_layer = ["X1", "X2"]
 
 
 
 
 
 
 
 
129
 
130
- # Hidden Layers
131
- for i, num_neurons in enumerate(neurons):
132
- current_layer = [f"H{i+1}{j+1}" for j in range(num_neurons)]
133
- for neuron in current_layer:
134
- graph.node(neuron, neuron, shape="circle", style="filled", fillcolor="lightyellow")
135
- for prev in prev_layer:
136
- for curr in current_layer:
137
- graph.edge(prev, curr)
138
- prev_layer = current_layer
139
 
140
- # Output Layer
141
- graph.node("Output", "Output", shape="circle", style="filled", fillcolor="lightgreen")
142
- for neuron in prev_layer:
143
- graph.edge(neuron, "Output")
 
 
 
 
144
 
145
- return graph
 
146
 
147
- # =================== DISPLAY DATA PLOT ===================
148
  st.sidebar.subheader("Dataset Visualization")
149
  fig, ax = plt.subplots()
150
- ax.scatter(X[:, 0], X[:, 1], c=y, cmap="plasma", edgecolors="k")
151
  st.sidebar.pyplot(fig)
152
 
153
- # =================== DISPLAY NEURAL NETWORK ===================
154
- st.graphviz_chart(draw_neural_network())
155
-
156
- # =================== TRAINING STATUS ===================
157
  if st.session_state.running:
158
- st.write("🚀 Training started...")
159
  elif not st.session_state.running and st.session_state.epoch > 0:
160
- st.write("⏸️ Training paused.")
 
4
  import seaborn as sns
5
  import graphviz
6
  import time
7
+ from sklearn.datasets import make_moons, make_circles, make_classification, make_blobs
8
 
9
  # Set Streamlit page title
10
  st.set_page_config(page_title="Neural Network Trainer", layout="wide")
11
 
12
+ # Session State for Training Controls
13
  if "epoch" not in st.session_state:
14
  st.session_state.epoch = 0
15
+ st.session_state.losses = []
16
  if "running" not in st.session_state:
17
  st.session_state.running = False
18
 
19
+ # Training Control Panel
20
  st.markdown("### Training Controls")
21
+ col1, col2, col3, col4, col5 = st.columns(5)
22
 
23
  with col1:
24
  if st.button("↩️ Reset"):
25
  st.session_state.epoch = 0
26
+ st.session_state.losses = []
27
  st.session_state.running = False
28
  with col2:
29
  if st.button("▶️ Train"):
 
32
  if st.button("⏸️ Pause"):
33
  st.session_state.running = False
34
  with col4:
35
+ activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"])
36
  with col5:
 
 
 
 
 
 
37
  learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1])
 
 
38
 
 
39
  if st.session_state.running:
40
  time.sleep(1) # Simulating training
41
  st.session_state.epoch += 1
42
+ st.session_state.losses.append(np.exp(-0.1 * st.session_state.epoch) + np.random.uniform(0, 0.05)) # Mock loss
43
 
44
+ # Sidebar - Dataset Selection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  st.sidebar.header("DATA")
46
+ dataset_option = st.sidebar.radio("Select Dataset", ("Moons", "Circles", "Spiral", "Blobs"))
47
+ train_ratio = st.sidebar.slider("Train-Test Split %", 10, 90, 50, format="%d%%")
48
+ noise = st.sidebar.slider("Noise", 0.0, 1.0, 0.1, step=0.1)
49
 
50
+ # Dataset Generation
 
 
 
 
 
 
 
 
 
 
51
  def generate_data():
52
  if dataset_option == "Moons":
53
+ X, y = make_moons(n_samples=1000, noise=noise)
54
  elif dataset_option == "Circles":
55
+ X, y = make_circles(n_samples=1000, noise=noise)
56
+ elif dataset_option == "Blobs":
57
+ X, y = make_blobs(n_samples=1000, centers=2, cluster_std=noise * 5)
58
  else:
59
+ theta = np.sqrt(np.random.rand(1000)) * 2 * np.pi
60
  r = theta
61
  X = np.array([r * np.cos(theta), r * np.sin(theta)]).T
62
  y = (theta % (2 * np.pi)) > np.pi
 
64
 
65
  X, y = generate_data()
66
 
67
+ # Feature Selection
68
+ st.header("FEATURES")
69
+ col_features, col_nn, col_plot = st.columns([2, 2, 3])
70
+
71
+ with col_features:
72
+ x1 = st.checkbox("X₁", value=True)
73
+ x2 = st.checkbox("X₂", value=True)
74
+ selected_features = [i for i, selected in enumerate([x1, x2]) if selected]
75
+ X = X[:, selected_features] if selected_features else np.zeros((X.shape[0], 1))
76
+
77
+ # Neural Network Display
78
  def draw_neural_network():
79
  graph = graphviz.Digraph()
80
+ prev_layer = [f"X{i+1}" for i in range(len(selected_features))]
81
+ for node in prev_layer:
82
+ graph.node(node, node, shape="circle", style="filled", fillcolor="lightblue")
83
+ graph.node("H1", "Hidden Neuron", shape="circle", style="filled", fillcolor="lightyellow")
84
+ for node in prev_layer:
85
+ graph.edge(node, "H1")
86
+ graph.node("Output", "Output", shape="circle", style="filled", fillcolor="lightgreen")
87
+ graph.edge("H1", "Output")
88
+ return graph
89
 
90
+ with col_nn:
91
+ st.graphviz_chart(draw_neural_network())
 
92
 
93
+ # Decision Region Plot
94
+ def plot_decision_regions():
95
+ xx, yy = np.meshgrid(np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
96
+ np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100))
97
+ Z = np.random.choice([0, 1], size=xx.shape) # Placeholder for model predictions
98
+ fig, ax = plt.subplots()
99
+ ax.contourf(xx, yy, Z, alpha=0.3, cmap="plasma")
100
+ sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, palette="plasma", edgecolor="k", ax=ax)
101
+ st.pyplot(fig)
102
 
103
+ with col_plot:
104
+ st.header("Decision Region")
105
+ plot_decision_regions()
 
 
 
 
 
 
106
 
107
+ # Loss Plot
108
+ def plot_loss():
109
+ fig, ax = plt.subplots()
110
+ ax.plot(range(len(st.session_state.losses)), st.session_state.losses, marker='o', linestyle='-')
111
+ ax.set_xlabel("Epoch")
112
+ ax.set_ylabel("Loss")
113
+ ax.set_title("Epoch vs Loss")
114
+ st.pyplot(fig)
115
 
116
+ st.header("Training Progress")
117
+ plot_loss()
118
 
119
+ # Dataset Visualization
120
  st.sidebar.subheader("Dataset Visualization")
121
  fig, ax = plt.subplots()
122
+ sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, palette="plasma", edgecolor="k", ax=ax)
123
  st.sidebar.pyplot(fig)
124
 
125
+ # Training Status
 
 
 
126
  if st.session_state.running:
127
+ st.write("🚀 Training in Progress...")
128
  elif not st.session_state.running and st.session_state.epoch > 0:
129
+ st.write("⏸️ Training Paused.")