trohith89 commited on
Commit
f8cfb29
·
verified ·
1 Parent(s): 21d57e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -156
app.py CHANGED
@@ -20,74 +20,66 @@ st.set_page_config(layout="wide")
20
  st.markdown("""
21
  <style>
22
  .stApp {
23
- background-color: #f5f5f5;
24
- color: #333;
25
  font-family: Arial, sans-serif;
26
  }
27
  h1, h2, h3 {
28
- color: #333;
29
  font-weight: bold;
30
  margin: 0;
31
  padding: 5px 0;
32
  }
33
  .stButton>button {
34
- background-color: #e0e0e0;
35
- color: #333;
36
- border: 2px solid #999;
37
  border-radius: 5px;
38
  padding: 5px 10px;
39
  font-size: 14px;
40
  font-weight: bold;
41
  }
42
  .stButton>button:hover {
43
- background-color: #c0c0c0;
44
- border-color: #777;
45
  }
46
  .stSelectbox, .stSlider {
47
- background-color: #fff;
48
- color: #333;
49
- border: 2px solid #999;
50
  border-radius: 5px;
51
  padding: 5px;
52
  }
53
  .stCheckbox label {
54
- color: #333;
55
  font-size: 14px;
56
  font-weight: bold;
57
  }
58
  .control-bar {
59
- background-color: #e0e0e0;
60
  padding: 10px;
61
- border: 2px solid #999;
62
  border-radius: 5px;
63
  margin-bottom: 10px;
64
  }
65
  .panel {
66
- background-color: #fff;
67
  padding: 10px;
68
- border: 2px solid #999;
69
  border-radius: 5px;
70
  margin: 10px 0;
71
  }
72
  .stSelectbox label, .stSlider label {
73
- color: #333;
74
  font-size: 12px;
75
  font-weight: bold;
76
  }
77
- .play-stop {
78
- background-color: #e0e0e0;
79
- border: 2px solid #999;
80
- border-radius: 5px;
81
- padding: 5px;
82
- margin-right: 10px;
83
- }
84
  </style>
85
  """, unsafe_allow_html=True)
86
 
87
  # Session state initialization
88
  if "training" not in st.session_state:
89
  st.session_state.training = False
90
- st.session_state.epoch = 0
91
  if "num_hidden_layers" not in st.session_state:
92
  st.session_state.num_hidden_layers = 2
93
  if "hidden_layer_neurons" not in st.session_state:
@@ -99,46 +91,36 @@ def reset_session():
99
  st.session_state.clear()
100
  st.session_state.num_hidden_layers = 2
101
  st.session_state.hidden_layer_neurons = [4, 2]
102
- st.session_state.training = False
103
- st.session_state.epoch = 0
104
 
105
  # Two-row top control bar
106
  with st.container():
107
  st.markdown('<div class="control-bar">', unsafe_allow_html=True)
108
- # Row 1: Play/Stop and Epoch
109
- col_play, col_epoch, col1, col2, col3 = st.columns([1, 2, 2, 2, 2])
110
- with col_play:
111
- col_play1, col_play2 = st.columns([1, 1])
112
- with col_play1:
113
- if st.button("⏪", key="rewind", help="Rewind"):
114
- pass # Placeholder for rewind functionality
115
- with col_play2:
116
- if st.button("▶️", key="play", on_click=lambda: setattr(st.session_state, "training", True)):
117
- st.session_state.training = True
118
- if st.button("⏹️", key="stop", on_click=lambda: setattr(st.session_state, "training", False)):
119
- st.session_state.training = False
120
- with col_epoch:
121
- st.write(f"Epoch: {st.session_state.epoch:06d}")
122
  with col1:
123
- learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.03, 0.1, 0.3, 1], index=2)
124
  with col2:
125
- activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
 
126
  with col3:
127
- reg_type = st.selectbox("Regularization", ["None", "L1", "L2"], index=0)
128
-
129
- # Row 2: Other controls
130
- col4, col5, col6, col7, col8 = st.columns([2, 2, 2, 2, 2])
131
  with col4:
132
- reg_rate = st.selectbox("Reg Rate", [0.0, 0.001, 0.01, 0.1, 1], index=0)
133
  with col5:
134
- problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
 
 
 
135
  with col6:
136
- dataset_options = {"Classification": ["Blobs", "Circles", "Spirals", "XOR"], "Regression": ["Sine Wave"]}
137
- dataset_type = st.selectbox("Dataset", dataset_options[problem_type])
138
  with col7:
139
- batch_size = st.slider("Batch Size", 1, 30, 10)
140
  with col8:
141
- noise_level = st.slider("Noise", 0, 50, 0, step=5)
 
 
 
 
142
  st.markdown('</div>', unsafe_allow_html=True)
143
 
144
  # Dataset generation
@@ -181,41 +163,28 @@ if problem_type == "Classification":
181
  # Main layout
182
  col_left, col_center, col_right = st.columns([1, 2, 1])
183
 
184
- # Left panel: Data and Features with Seaborn
185
  with col_left:
186
  st.markdown('<div class="panel">', unsafe_allow_html=True)
187
- st.subheader("DATA")
188
- st.write("Which dataset do you want to use?")
189
  fig, ax = plt.subplots(figsize=(3, 3))
190
  if problem_type == "Classification":
191
  sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
192
- plt.colorbar(ax.collections[0], ax=ax, label="Class Probability", shrink=0.5)
193
  else:
194
  sns.scatterplot(x=fv[:, 0], y=cv, color="blue", edgecolor="k", alpha=0.7, ax=ax)
195
- ax.set_facecolor("#e6f3ff")
196
  ax.set_xticks([])
197
  ax.set_yticks([])
 
198
  st.pyplot(fig)
199
- st.subheader("FEATURES")
200
- st.write("Which properties do you want to feed in?")
201
  for feature in features.keys():
202
  st.checkbox(feature, value=feature in ["X1", "X2"], key=feature)
203
- col_train, col_noise, col_batch = st.columns(3)
204
- with col_train:
205
- train_ratio = st.slider("Ratio of training to test data: 50%", 10, 90, 50, 10) / 100
206
- with col_noise:
207
- noise_level = st.slider("Noise: 0", 0, 50, 0, step=5)
208
- with col_batch:
209
- batch_size = st.slider("Batch size: 10", 1, 30, 10)
210
- st.button("REGENERATE", key="regenerate")
211
- st.checkbox("Show test data", key="show_test_data")
212
- st.checkbox("Discretize output", key="discretize_output")
213
  st.markdown('</div>', unsafe_allow_html=True)
214
 
215
- # Center panel: Horizontal ANN Visualization
216
  with col_center:
217
  st.markdown('<div class="panel">', unsafe_allow_html=True)
218
- st.subheader("HIDDEN LAYERS")
219
 
220
  def draw_nn(features, hidden_neurons):
221
  G = nx.DiGraph()
@@ -229,16 +198,16 @@ with col_center:
229
  for node in layer:
230
  G.add_node(node, layer=layer_idx)
231
  if layer_idx == 0:
232
- node_colors[node] = "#ff9f40" # Orange for input
233
  elif layer_idx == len(all_layers) - 1:
234
- node_colors[node] = "#ff9f40" # Orange for output
235
  else:
236
- node_colors[node] = "#40a0ff" # Blue for hidden
237
 
238
  for i in range(len(all_layers) - 1):
239
  for node1 in all_layers[i]:
240
  for node2 in all_layers[i + 1]:
241
- G.add_edge(node1, node2, weight=np.random.uniform(0.1, 1.0)) # Random weight for thickness
242
 
243
  pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")
244
  pos_rotated = {node: (-y, x) for node, (x, y) in pos.items()}
@@ -246,42 +215,22 @@ with col_center:
246
  pos_rotated[node] = (pos_rotated[node][0] * 2, pos_rotated[node][1] * 2)
247
 
248
  fig, ax = plt.subplots(figsize=(8, 4))
249
- ax.set_facecolor("#f5f5f5")
250
- edge_colors = [plt.cm.RdBu(G[u][v]['weight']) for u, v in G.edges()]
251
  nx.draw(
252
  G, pos_rotated,
253
  with_labels=True,
254
  node_color=[node_colors[node] for node in G.nodes()],
255
- edge_color=edge_colors,
256
- node_shape="s", # Square nodes
257
- node_size=1200,
258
  font_size=8,
259
  font_color="black",
260
  font_weight="bold",
261
  edgecolors="black",
262
- width=[G[u][v]['weight'] * 2 for u, v in G.edges()], # Vary thickness
263
- arrows=False,
264
  ax=ax
265
  )
266
- plt.title("Neural Network Structure", color="#333", fontsize=12, pad=10)
267
-
268
- # Add + and - buttons for layers
269
- col_plus, col_minus = st.columns([1, 1])
270
- with col_plus:
271
- st.button("+", key="add_layer", on_click=add_layer)
272
- with col_minus:
273
- st.button("−", key="remove_layer", on_click=remove_layer)
274
-
275
- # Layer neuron controls
276
- for i in range(st.session_state.num_hidden_layers):
277
- col_dec, col_label, col_inc = st.columns([1, 2, 1])
278
- with col_dec:
279
- st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,))
280
- with col_label:
281
- st.write(f"{st.session_state.hidden_layer_neurons[i]} neurons")
282
- with col_inc:
283
- st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
284
-
285
  return fig
286
 
287
  st.pyplot(draw_nn(selected_features, st.session_state.hidden_layer_neurons))
@@ -304,54 +253,33 @@ with col_center:
304
  if st.session_state.hidden_layer_neurons[i] > 1:
305
  st.session_state.hidden_layer_neurons[i] -= 1
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  st.markdown('</div>', unsafe_allow_html=True)
308
 
309
- # Right panel: Output
310
  with col_right:
311
  st.markdown('<div class="panel">', unsafe_allow_html=True)
312
- st.subheader("OUTPUT")
313
- if st.session_state.training:
314
- train_loss = 0.505 # Simulated value
315
- test_loss = 0.513 # Simulated value
316
- st.write(f"Training loss: {train_loss:.3f}")
317
- st.write(f"Test loss: {test_loss:.3f}")
318
- col1, col2 = st.columns(2)
319
- with col1:
320
- plt.figure(figsize=(3, 3))
321
- if problem_type == "Classification":
322
- y_pred_proba = model.predict(selected_data, verbose=0)
323
- y_pred = (y_pred_proba > 0.5).astype(int).ravel()
324
- try:
325
- plot_decision_regions(selected_data[:, :2], cv, clf=model, legend=2)
326
- except Exception as e:
327
- st.warning(f"Decision region plot failed: {e}")
328
- xx, yy = np.meshgrid(np.linspace(selected_data[:, 0].min(), selected_data[:, 0].max(), 100),
329
- np.linspace(selected_data[:, 1].min(), selected_data[:, 1].max(), 100))
330
- grid = np.c_[xx.ravel(), yy.ravel()]
331
- Z = model.predict(grid, verbose=0)
332
- Z = (Z > 0.5).astype(int).reshape(xx.shape)
333
- plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
334
- plt.scatter(selected_data[:, 0], selected_data[:, 1], c=cv, cmap="coolwarm", edgecolors="k", alpha=0.7)
335
- else:
336
- y_pred = model.predict(selected_data, verbose=0)
337
- plt.scatter(selected_data[:, 0], cv, c="blue", alpha=0.5)
338
- plt.plot(selected_data[:, 0], y_pred, "r-", linewidth=2)
339
- plt.gca().set_facecolor("#e6f3ff")
340
- plt.xticks([])
341
- plt.yticks([])
342
- st.pyplot(plt)
343
- with col2:
344
- fig, ax = plt.subplots(figsize=(3, 3))
345
- ax.plot([1, 2, 3], [0.5, 0.5, 0.5], "b-", label="Train") # Simulated loss
346
- ax.plot([1, 2, 3], [0.51, 0.51, 0.51], "r--", label="Val") # Simulated loss
347
- ax.legend()
348
- ax.set_facecolor("#e6f3ff")
349
- st.pyplot(fig)
350
- st.write("Colors shows data, neuron and weight values.")
351
- st.markdown('</div>', unsafe_allow_html=True)
352
 
353
- # Training logic (moved outside for clarity)
354
- if st.session_state.training:
355
  def create_model(input_dim, neurons):
356
  model = Sequential()
357
  model.add(Input(shape=(input_dim,)))
@@ -372,7 +300,6 @@ if st.session_state.training:
372
  self.placeholder = st.empty()
373
 
374
  def on_epoch_end(self, epoch, logs=None):
375
- st.session_state.epoch = epoch + 1
376
  self.losses["Epoch"].append(epoch + 1)
377
  self.losses["Train Loss"].append(logs["loss"])
378
  self.losses["Val Loss"].append(logs["val_loss"])
@@ -381,24 +308,28 @@ if st.session_state.training:
381
  with col1:
382
  plt.figure(figsize=(3, 3))
383
  if problem_type == "Classification":
384
- y_pred_proba = self.model.predict(self.X, verbose=0)
 
 
385
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
386
  try:
387
- plot_decision_regions(self.X[:, :2], self.y, clf=self.model, legend=2)
 
388
  except Exception as e:
389
  st.warning(f"Decision region plot failed: {e}")
390
- xx, yy = np.meshgrid(np.linspace(self.X[:, 0].min(), self.X[:, 0].max(), 100),
391
- np.linspace(self.X[:, 1].min(), self.X[:, 1].max(), 100))
392
  grid = np.c_[xx.ravel(), yy.ravel()]
393
  Z = self.model.predict(grid, verbose=0)
394
  Z = (Z > 0.5).astype(int).reshape(xx.shape)
 
395
  plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
396
- plt.scatter(self.X[:, 0], self.X[:, 1], c=self.y, cmap="coolwarm", edgecolors="k", alpha=0.7)
397
  else:
398
  y_pred = self.model.predict(self.X, verbose=0)
399
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
400
  plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
401
- plt.gca().set_facecolor("#e6f3ff")
402
  plt.xticks([])
403
  plt.yticks([])
404
  st.pyplot(plt)
@@ -407,9 +338,11 @@ if st.session_state.training:
407
  ax.plot(self.losses["Epoch"], self.losses["Train Loss"], "b-", label="Train")
408
  ax.plot(self.losses["Epoch"], self.losses["Val Loss"], "r--", label="Val")
409
  ax.legend()
410
- ax.set_facecolor("#e6f3ff")
411
  st.pyplot(fig)
412
 
413
- model = create_model(len(selected_features), st.session_state.hidden_layer_neurons)
414
- callback = OutputCallback(selected_data, cv)
415
- model.fit(selected_data, cv, epochs=999999, batch_size=batch_size, validation_split=1-train_ratio, callbacks=[callback], verbose=0)
 
 
 
20
  st.markdown("""
21
  <style>
22
  .stApp {
23
+ background-color: #252830;
24
+ color: white;
25
  font-family: Arial, sans-serif;
26
  }
27
  h1, h2, h3 {
28
+ color: white;
29
  font-weight: bold;
30
  margin: 0;
31
  padding: 5px 0;
32
  }
33
  .stButton>button {
34
+ background-color: #555;
35
+ color: white;
36
+ border: 2px solid #777;
37
  border-radius: 5px;
38
  padding: 5px 10px;
39
  font-size: 14px;
40
  font-weight: bold;
41
  }
42
  .stButton>button:hover {
43
+ background-color: #777;
44
+ border-color: #999;
45
  }
46
  .stSelectbox, .stSlider {
47
+ background-color: #333;
48
+ color: white;
49
+ border: 2px solid #777;
50
  border-radius: 5px;
51
  padding: 5px;
52
  }
53
  .stCheckbox label {
54
+ color: white;
55
  font-size: 14px;
56
  font-weight: bold;
57
  }
58
  .control-bar {
59
+ background-color: #1e2126;
60
  padding: 10px;
61
+ border: 2px solid #333;
62
  border-radius: 5px;
63
  margin-bottom: 10px;
64
  }
65
  .panel {
66
+ background-color: #2e3238;
67
  padding: 10px;
68
+ border: 2px solid #777;
69
  border-radius: 5px;
70
  margin: 10px 0;
71
  }
72
  .stSelectbox label, .stSlider label {
73
+ color: white;
74
  font-size: 12px;
75
  font-weight: bold;
76
  }
 
 
 
 
 
 
 
77
  </style>
78
  """, unsafe_allow_html=True)
79
 
80
  # Session state initialization
81
  if "training" not in st.session_state:
82
  st.session_state.training = False
 
83
  if "num_hidden_layers" not in st.session_state:
84
  st.session_state.num_hidden_layers = 2
85
  if "hidden_layer_neurons" not in st.session_state:
 
91
  st.session_state.clear()
92
  st.session_state.num_hidden_layers = 2
93
  st.session_state.hidden_layer_neurons = [4, 2]
 
 
94
 
95
  # Two-row top control bar
96
  with st.container():
97
  st.markdown('<div class="control-bar">', unsafe_allow_html=True)
98
+ # Row 1
99
+ col1, col2, col3, col4, col5 = st.columns(5)
 
 
 
 
 
 
 
 
 
 
 
 
100
  with col1:
101
+ problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
102
  with col2:
103
+ dataset_options = {"Classification": ["Blobs", "Circles", "Spirals", "XOR"], "Regression": ["Sine Wave"]}
104
+ dataset_type = st.selectbox("Dataset", dataset_options[problem_type])
105
  with col3:
106
+ learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.03, 0.1, 0.3, 1], index=2)
 
 
 
107
  with col4:
108
+ activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
109
  with col5:
110
+ batch_size = st.slider("Batch Size", 1, 30, 10)
111
+
112
+ # Row 2
113
+ col6, col7, col8, col9, col10 = st.columns(5)
114
  with col6:
115
+ noise_level = st.slider("Noise", 0, 50, 0, step=5)
 
116
  with col7:
117
+ reg_type = st.selectbox("Regularization", ["None", "L1", "L2"], index=0)
118
  with col8:
119
+ reg_rate = st.selectbox("Reg Rate", [0.0, 0.001, 0.01, 0.1, 1], index=0)
120
+ with col9:
121
+ train_ratio = st.slider("Train %", 10, 90, 50, 10) / 100
122
+ with col10:
123
+ st.button("Reset", key="reset_global", on_click=reset_session)
124
  st.markdown('</div>', unsafe_allow_html=True)
125
 
126
  # Dataset generation
 
163
  # Main layout
164
  col_left, col_center, col_right = st.columns([1, 2, 1])
165
 
166
+ # Left panel: Dataset with Seaborn
167
  with col_left:
168
  st.markdown('<div class="panel">', unsafe_allow_html=True)
169
+ st.subheader("Data")
 
170
  fig, ax = plt.subplots(figsize=(3, 3))
171
  if problem_type == "Classification":
172
  sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
 
173
  else:
174
  sns.scatterplot(x=fv[:, 0], y=cv, color="blue", edgecolor="k", alpha=0.7, ax=ax)
 
175
  ax.set_xticks([])
176
  ax.set_yticks([])
177
+ ax.set_facecolor("#333")
178
  st.pyplot(fig)
179
+ st.subheader("Features")
 
180
  for feature in features.keys():
181
  st.checkbox(feature, value=feature in ["X1", "X2"], key=feature)
 
 
 
 
 
 
 
 
 
 
182
  st.markdown('</div>', unsafe_allow_html=True)
183
 
184
+ # Center panel: Horizontal Network Visualization
185
  with col_center:
186
  st.markdown('<div class="panel">', unsafe_allow_html=True)
187
+ st.subheader("Network")
188
 
189
  def draw_nn(features, hidden_neurons):
190
  G = nx.DiGraph()
 
198
  for node in layer:
199
  G.add_node(node, layer=layer_idx)
200
  if layer_idx == 0:
201
+ node_colors[node] = "#90EE90" # Green for input
202
  elif layer_idx == len(all_layers) - 1:
203
+ node_colors[node] = "#FFA07A" # Orange for output
204
  else:
205
+ node_colors[node] = "#87CEFA" # Blue for hidden
206
 
207
  for i in range(len(all_layers) - 1):
208
  for node1 in all_layers[i]:
209
  for node2 in all_layers[i + 1]:
210
+ G.add_edge(node1, node2)
211
 
212
  pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")
213
  pos_rotated = {node: (-y, x) for node, (x, y) in pos.items()}
 
215
  pos_rotated[node] = (pos_rotated[node][0] * 2, pos_rotated[node][1] * 2)
216
 
217
  fig, ax = plt.subplots(figsize=(8, 4))
218
+ ax.set_facecolor("#252830")
 
219
  nx.draw(
220
  G, pos_rotated,
221
  with_labels=True,
222
  node_color=[node_colors[node] for node in G.nodes()],
223
+ edge_color="white",
224
+ node_size=600,
 
225
  font_size=8,
226
  font_color="black",
227
  font_weight="bold",
228
  edgecolors="black",
229
+ width=1.0,
230
+ arrows=True,
231
  ax=ax
232
  )
233
+ plt.title("Neural Network Structure", color="white", fontsize=12, pad=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return fig
235
 
236
  st.pyplot(draw_nn(selected_features, st.session_state.hidden_layer_neurons))
 
253
  if st.session_state.hidden_layer_neurons[i] > 1:
254
  st.session_state.hidden_layer_neurons[i] -= 1
255
 
256
+ for i in range(st.session_state.num_hidden_layers):
257
+ col1, col2, col3 = st.columns([1, 2, 1])
258
+ with col1:
259
+ st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,))
260
+ with col2:
261
+ st.write(f"Layer {i+1}: {st.session_state.hidden_layer_neurons[i]}")
262
+ with col3:
263
+ st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
264
+ col_btn1, col_btn2 = st.columns(2)
265
+ with col_btn1:
266
+ st.button("Add Layer", on_click=add_layer)
267
+ with col_btn2:
268
+ st.button("Remove Layer", on_click=remove_layer)
269
  st.markdown('</div>', unsafe_allow_html=True)
270
 
271
+ # Right panel: Output and Training
272
  with col_right:
273
  st.markdown('<div class="panel">', unsafe_allow_html=True)
274
+ st.subheader("Output")
275
+ col_start, col_stop = st.columns(2)
276
+ with col_start:
277
+ if st.button("▶️ Play"):
278
+ st.session_state.training = True
279
+ with col_stop:
280
+ if st.button("⏹️ Stop"):
281
+ st.session_state.training = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
 
 
283
  def create_model(input_dim, neurons):
284
  model = Sequential()
285
  model.add(Input(shape=(input_dim,)))
 
300
  self.placeholder = st.empty()
301
 
302
  def on_epoch_end(self, epoch, logs=None):
 
303
  self.losses["Epoch"].append(epoch + 1)
304
  self.losses["Train Loss"].append(logs["loss"])
305
  self.losses["Val Loss"].append(logs["val_loss"])
 
308
  with col1:
309
  plt.figure(figsize=(3, 3))
310
  if problem_type == "Classification":
311
+ # Ensure we use only the first two features for 2D plotting
312
+ X_2d = self.X[:, :2] # Use only X1 and X2 for decision boundary
313
+ y_pred_proba = self.model.predict(X_2d, verbose=0)
314
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
315
  try:
316
+ plot_decision_regions(X_2d, self.y, clf=self.model, legend=2, colors='blue,red')
317
+ plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap='coolwarm', edgecolors='k', alpha=0.7)
318
  except Exception as e:
319
  st.warning(f"Decision region plot failed: {e}")
320
+ xx, yy = np.meshgrid(np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), 100),
321
+ np.linspace(X_2d[:, 1].min(), X_2d[:, 1].max(), 100))
322
  grid = np.c_[xx.ravel(), yy.ravel()]
323
  Z = self.model.predict(grid, verbose=0)
324
  Z = (Z > 0.5).astype(int).reshape(xx.shape)
325
+ plt.contour(xx, yy, Z, levels=[0.5], colors='black', linewidths=2) # Clear decision boundary
326
  plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
327
+ plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap="coolwarm", edgecolors="k", alpha=0.7)
328
  else:
329
  y_pred = self.model.predict(self.X, verbose=0)
330
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
331
  plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
332
+ plt.gca().set_facecolor("#333")
333
  plt.xticks([])
334
  plt.yticks([])
335
  st.pyplot(plt)
 
338
  ax.plot(self.losses["Epoch"], self.losses["Train Loss"], "b-", label="Train")
339
  ax.plot(self.losses["Epoch"], self.losses["Val Loss"], "r--", label="Val")
340
  ax.legend()
341
+ ax.set_facecolor("#333")
342
  st.pyplot(fig)
343
 
344
+ if st.session_state.training:
345
+ model = create_model(len(selected_features), st.session_state.hidden_layer_neurons)
346
+ callback = OutputCallback(selected_data, cv)
347
+ model.fit(selected_data, cv, epochs=999999, batch_size=batch_size, validation_split=1-train_ratio, callbacks=[callback], verbose=0)
348
+ st.markdown('</div>', unsafe_allow_html=True)