trohith89 commited on
Commit
21d57e7
·
verified ·
1 Parent(s): a74998f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -82
app.py CHANGED
@@ -15,71 +15,79 @@ from keras.losses import MeanSquaredError, BinaryCrossentropy
15
  from keras.regularizers import l2, l1
16
  from keras.callbacks import Callback
17
 
18
- # Set wide layout and updated CSS
19
  st.set_page_config(layout="wide")
20
  st.markdown("""
21
  <style>
22
  .stApp {
23
- background-color: #252830;
24
- color: white;
25
  font-family: Arial, sans-serif;
26
  }
27
  h1, h2, h3 {
28
- color: white;
29
  font-weight: bold;
30
  margin: 0;
31
  padding: 5px 0;
32
  }
33
  .stButton>button {
34
- background-color: #555;
35
- color: white;
36
- border: 2px solid #777;
37
  border-radius: 5px;
38
  padding: 5px 10px;
39
  font-size: 14px;
40
  font-weight: bold;
41
  }
42
  .stButton>button:hover {
43
- background-color: #777;
44
- border-color: #999;
45
  }
46
  .stSelectbox, .stSlider {
47
- background-color: #333;
48
- color: white;
49
- border: 2px solid #777;
50
  border-radius: 5px;
51
  padding: 5px;
52
  }
53
  .stCheckbox label {
54
- color: white;
55
  font-size: 14px;
56
  font-weight: bold;
57
  }
58
  .control-bar {
59
- background-color: #1e2126;
60
  padding: 10px;
61
- border: 2px solid #333;
62
  border-radius: 5px;
63
  margin-bottom: 10px;
64
  }
65
  .panel {
66
- background-color: #2e3238;
67
  padding: 10px;
68
- border: 2px solid #777;
69
  border-radius: 5px;
70
  margin: 10px 0;
71
  }
72
  .stSelectbox label, .stSlider label {
73
- color: white;
74
  font-size: 12px;
75
  font-weight: bold;
76
  }
 
 
 
 
 
 
 
77
  </style>
78
  """, unsafe_allow_html=True)
79
 
80
  # Session state initialization
81
  if "training" not in st.session_state:
82
  st.session_state.training = False
 
83
  if "num_hidden_layers" not in st.session_state:
84
  st.session_state.num_hidden_layers = 2
85
  if "hidden_layer_neurons" not in st.session_state:
@@ -91,36 +99,46 @@ def reset_session():
91
  st.session_state.clear()
92
  st.session_state.num_hidden_layers = 2
93
  st.session_state.hidden_layer_neurons = [4, 2]
 
 
94
 
95
  # Two-row top control bar
96
  with st.container():
97
  st.markdown('<div class="control-bar">', unsafe_allow_html=True)
98
- # Row 1
99
- col1, col2, col3, col4, col5 = st.columns(5)
 
 
 
 
 
 
 
 
 
 
 
 
100
  with col1:
101
- problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
102
  with col2:
103
- dataset_options = {"Classification": ["Blobs", "Circles", "Spirals", "XOR"], "Regression": ["Sine Wave"]}
104
- dataset_type = st.selectbox("Dataset", dataset_options[problem_type])
105
  with col3:
106
- learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.03, 0.1, 0.3, 1], index=2)
 
 
 
107
  with col4:
108
- activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
109
  with col5:
110
- batch_size = st.slider("Batch Size", 1, 30, 10)
111
-
112
- # Row 2
113
- col6, col7, col8, col9, col10 = st.columns(5)
114
  with col6:
115
- noise_level = st.slider("Noise", 0, 50, 0, step=5)
 
116
  with col7:
117
- reg_type = st.selectbox("Regularization", ["None", "L1", "L2"], index=0)
118
  with col8:
119
- reg_rate = st.selectbox("Reg Rate", [0.0, 0.001, 0.01, 0.1, 1], index=0)
120
- with col9:
121
- train_ratio = st.slider("Train %", 10, 90, 50, 10) / 100
122
- with col10:
123
- st.button("Reset", key="reset_global", on_click=reset_session)
124
  st.markdown('</div>', unsafe_allow_html=True)
125
 
126
  # Dataset generation
@@ -163,28 +181,41 @@ if problem_type == "Classification":
163
  # Main layout
164
  col_left, col_center, col_right = st.columns([1, 2, 1])
165
 
166
- # Left panel: Dataset with Seaborn
167
  with col_left:
168
  st.markdown('<div class="panel">', unsafe_allow_html=True)
169
- st.subheader("Data")
 
170
  fig, ax = plt.subplots(figsize=(3, 3))
171
  if problem_type == "Classification":
172
  sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
 
173
  else:
174
  sns.scatterplot(x=fv[:, 0], y=cv, color="blue", edgecolor="k", alpha=0.7, ax=ax)
 
175
  ax.set_xticks([])
176
  ax.set_yticks([])
177
- ax.set_facecolor("#333")
178
  st.pyplot(fig)
179
- st.subheader("Features")
 
180
  for feature in features.keys():
181
  st.checkbox(feature, value=feature in ["X1", "X2"], key=feature)
 
 
 
 
 
 
 
 
 
 
182
  st.markdown('</div>', unsafe_allow_html=True)
183
 
184
- # Center panel: Horizontal Network Visualization
185
  with col_center:
186
  st.markdown('<div class="panel">', unsafe_allow_html=True)
187
- st.subheader("Network")
188
 
189
  def draw_nn(features, hidden_neurons):
190
  G = nx.DiGraph()
@@ -198,16 +229,16 @@ with col_center:
198
  for node in layer:
199
  G.add_node(node, layer=layer_idx)
200
  if layer_idx == 0:
201
- node_colors[node] = "#90EE90"
202
  elif layer_idx == len(all_layers) - 1:
203
- node_colors[node] = "#FFA07A"
204
  else:
205
- node_colors[node] = "#87CEFA"
206
 
207
  for i in range(len(all_layers) - 1):
208
  for node1 in all_layers[i]:
209
  for node2 in all_layers[i + 1]:
210
- G.add_edge(node1, node2)
211
 
212
  pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")
213
  pos_rotated = {node: (-y, x) for node, (x, y) in pos.items()}
@@ -215,22 +246,42 @@ with col_center:
215
  pos_rotated[node] = (pos_rotated[node][0] * 2, pos_rotated[node][1] * 2)
216
 
217
  fig, ax = plt.subplots(figsize=(8, 4))
218
- ax.set_facecolor("#252830")
 
219
  nx.draw(
220
  G, pos_rotated,
221
  with_labels=True,
222
  node_color=[node_colors[node] for node in G.nodes()],
223
- edge_color="white",
224
- node_size=600,
 
225
  font_size=8,
226
  font_color="black",
227
  font_weight="bold",
228
  edgecolors="black",
229
- width=1.0,
230
- arrows=True,
231
  ax=ax
232
  )
233
- plt.title("Neural Network Structure", color="white", fontsize=12, pad=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  return fig
235
 
236
  st.pyplot(draw_nn(selected_features, st.session_state.hidden_layer_neurons))
@@ -253,33 +304,54 @@ with col_center:
253
  if st.session_state.hidden_layer_neurons[i] > 1:
254
  st.session_state.hidden_layer_neurons[i] -= 1
255
 
256
- for i in range(st.session_state.num_hidden_layers):
257
- col1, col2, col3 = st.columns([1, 2, 1])
258
- with col1:
259
- st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,)) # Unicode minus
260
- with col2:
261
- st.write(f"Layer {i+1}: {st.session_state.hidden_layer_neurons[i]}")
262
- with col3:
263
- st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
264
- col_btn1, col_btn2 = st.columns(2)
265
- with col_btn1:
266
- st.button("Add Layer", on_click=add_layer)
267
- with col_btn2:
268
- st.button("Remove Layer", on_click=remove_layer)
269
  st.markdown('</div>', unsafe_allow_html=True)
270
 
271
- # Right panel: Output and Training
272
  with col_right:
273
  st.markdown('<div class="panel">', unsafe_allow_html=True)
274
- st.subheader("Output")
275
- col_start, col_stop = st.columns(2)
276
- with col_start:
277
- if st.button("▶️ Play"):
278
- st.session_state.training = True
279
- with col_stop:
280
- if st.button("⏹️ Stop"):
281
- st.session_state.training = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
 
 
283
  def create_model(input_dim, neurons):
284
  model = Sequential()
285
  model.add(Input(shape=(input_dim,)))
@@ -300,6 +372,7 @@ with col_right:
300
  self.placeholder = st.empty()
301
 
302
  def on_epoch_end(self, epoch, logs=None):
 
303
  self.losses["Epoch"].append(epoch + 1)
304
  self.losses["Train Loss"].append(logs["loss"])
305
  self.losses["Val Loss"].append(logs["val_loss"])
@@ -311,7 +384,7 @@ with col_right:
311
  y_pred_proba = self.model.predict(self.X, verbose=0)
312
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
313
  try:
314
- plot_decision_regions(self.X, self.y, clf=self.model, legend=2)
315
  except Exception as e:
316
  st.warning(f"Decision region plot failed: {e}")
317
  xx, yy = np.meshgrid(np.linspace(self.X[:, 0].min(), self.X[:, 0].max(), 100),
@@ -325,7 +398,7 @@ with col_right:
325
  y_pred = self.model.predict(self.X, verbose=0)
326
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
327
  plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
328
- plt.gca().set_facecolor("#333")
329
  plt.xticks([])
330
  plt.yticks([])
331
  st.pyplot(plt)
@@ -334,11 +407,9 @@ with col_right:
334
  ax.plot(self.losses["Epoch"], self.losses["Train Loss"], "b-", label="Train")
335
  ax.plot(self.losses["Epoch"], self.losses["Val Loss"], "r--", label="Val")
336
  ax.legend()
337
- ax.set_facecolor("#333")
338
  st.pyplot(fig)
339
 
340
- if st.session_state.training:
341
- model = create_model(len(selected_features), st.session_state.hidden_layer_neurons)
342
- callback = OutputCallback(selected_data, cv)
343
- model.fit(selected_data, cv, epochs=999999, batch_size=batch_size, validation_split=1-train_ratio, callbacks=[callback], verbose=0)
344
- st.markdown('</div>', unsafe_allow_html=True)
 
15
  from keras.regularizers import l2, l1
16
  from keras.callbacks import Callback
17
 
18
+ # Set wide layout and TensorFlow Playground CSS
19
  st.set_page_config(layout="wide")
20
  st.markdown("""
21
  <style>
22
  .stApp {
23
+ background-color: #f5f5f5;
24
+ color: #333;
25
  font-family: Arial, sans-serif;
26
  }
27
  h1, h2, h3 {
28
+ color: #333;
29
  font-weight: bold;
30
  margin: 0;
31
  padding: 5px 0;
32
  }
33
  .stButton>button {
34
+ background-color: #e0e0e0;
35
+ color: #333;
36
+ border: 2px solid #999;
37
  border-radius: 5px;
38
  padding: 5px 10px;
39
  font-size: 14px;
40
  font-weight: bold;
41
  }
42
  .stButton>button:hover {
43
+ background-color: #c0c0c0;
44
+ border-color: #777;
45
  }
46
  .stSelectbox, .stSlider {
47
+ background-color: #fff;
48
+ color: #333;
49
+ border: 2px solid #999;
50
  border-radius: 5px;
51
  padding: 5px;
52
  }
53
  .stCheckbox label {
54
+ color: #333;
55
  font-size: 14px;
56
  font-weight: bold;
57
  }
58
  .control-bar {
59
+ background-color: #e0e0e0;
60
  padding: 10px;
61
+ border: 2px solid #999;
62
  border-radius: 5px;
63
  margin-bottom: 10px;
64
  }
65
  .panel {
66
+ background-color: #fff;
67
  padding: 10px;
68
+ border: 2px solid #999;
69
  border-radius: 5px;
70
  margin: 10px 0;
71
  }
72
  .stSelectbox label, .stSlider label {
73
+ color: #333;
74
  font-size: 12px;
75
  font-weight: bold;
76
  }
77
+ .play-stop {
78
+ background-color: #e0e0e0;
79
+ border: 2px solid #999;
80
+ border-radius: 5px;
81
+ padding: 5px;
82
+ margin-right: 10px;
83
+ }
84
  </style>
85
  """, unsafe_allow_html=True)
86
 
87
  # Session state initialization
88
  if "training" not in st.session_state:
89
  st.session_state.training = False
90
+ st.session_state.epoch = 0
91
  if "num_hidden_layers" not in st.session_state:
92
  st.session_state.num_hidden_layers = 2
93
  if "hidden_layer_neurons" not in st.session_state:
 
99
  st.session_state.clear()
100
  st.session_state.num_hidden_layers = 2
101
  st.session_state.hidden_layer_neurons = [4, 2]
102
+ st.session_state.training = False
103
+ st.session_state.epoch = 0
104
 
105
  # Two-row top control bar
106
  with st.container():
107
  st.markdown('<div class="control-bar">', unsafe_allow_html=True)
108
+ # Row 1: Play/Stop and Epoch
109
+ col_play, col_epoch, col1, col2, col3 = st.columns([1, 2, 2, 2, 2])
110
+ with col_play:
111
+ col_play1, col_play2 = st.columns([1, 1])
112
+ with col_play1:
113
+ if st.button("⏪", key="rewind", help="Rewind"):
114
+ pass # Placeholder for rewind functionality
115
+ with col_play2:
116
+ if st.button("▶️", key="play", on_click=lambda: setattr(st.session_state, "training", True)):
117
+ st.session_state.training = True
118
+ if st.button("⏹️", key="stop", on_click=lambda: setattr(st.session_state, "training", False)):
119
+ st.session_state.training = False
120
+ with col_epoch:
121
+ st.write(f"Epoch: {st.session_state.epoch:06d}")
122
  with col1:
123
+ learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.03, 0.1, 0.3, 1], index=2)
124
  with col2:
125
+ activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
 
126
  with col3:
127
+ reg_type = st.selectbox("Regularization", ["None", "L1", "L2"], index=0)
128
+
129
+ # Row 2: Other controls
130
+ col4, col5, col6, col7, col8 = st.columns([2, 2, 2, 2, 2])
131
  with col4:
132
+ reg_rate = st.selectbox("Reg Rate", [0.0, 0.001, 0.01, 0.1, 1], index=0)
133
  with col5:
134
+ problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
 
 
 
135
  with col6:
136
+ dataset_options = {"Classification": ["Blobs", "Circles", "Spirals", "XOR"], "Regression": ["Sine Wave"]}
137
+ dataset_type = st.selectbox("Dataset", dataset_options[problem_type])
138
  with col7:
139
+ batch_size = st.slider("Batch Size", 1, 30, 10)
140
  with col8:
141
+ noise_level = st.slider("Noise", 0, 50, 0, step=5)
 
 
 
 
142
  st.markdown('</div>', unsafe_allow_html=True)
143
 
144
  # Dataset generation
 
181
  # Main layout
182
  col_left, col_center, col_right = st.columns([1, 2, 1])
183
 
184
+ # Left panel: Data and Features with Seaborn
185
  with col_left:
186
  st.markdown('<div class="panel">', unsafe_allow_html=True)
187
+ st.subheader("DATA")
188
+ st.write("Which dataset do you want to use?")
189
  fig, ax = plt.subplots(figsize=(3, 3))
190
  if problem_type == "Classification":
191
  sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
192
+ plt.colorbar(ax.collections[0], ax=ax, label="Class Probability", shrink=0.5)
193
  else:
194
  sns.scatterplot(x=fv[:, 0], y=cv, color="blue", edgecolor="k", alpha=0.7, ax=ax)
195
+ ax.set_facecolor("#e6f3ff")
196
  ax.set_xticks([])
197
  ax.set_yticks([])
 
198
  st.pyplot(fig)
199
+ st.subheader("FEATURES")
200
+ st.write("Which properties do you want to feed in?")
201
  for feature in features.keys():
202
  st.checkbox(feature, value=feature in ["X1", "X2"], key=feature)
203
+ col_train, col_noise, col_batch = st.columns(3)
204
+ with col_train:
205
+ train_ratio = st.slider("Ratio of training to test data: 50%", 10, 90, 50, 10) / 100
206
+ with col_noise:
207
+ noise_level = st.slider("Noise: 0", 0, 50, 0, step=5)
208
+ with col_batch:
209
+ batch_size = st.slider("Batch size: 10", 1, 30, 10)
210
+ st.button("REGENERATE", key="regenerate")
211
+ st.checkbox("Show test data", key="show_test_data")
212
+ st.checkbox("Discretize output", key="discretize_output")
213
  st.markdown('</div>', unsafe_allow_html=True)
214
 
215
+ # Center panel: Horizontal ANN Visualization
216
  with col_center:
217
  st.markdown('<div class="panel">', unsafe_allow_html=True)
218
+ st.subheader("HIDDEN LAYERS")
219
 
220
  def draw_nn(features, hidden_neurons):
221
  G = nx.DiGraph()
 
229
  for node in layer:
230
  G.add_node(node, layer=layer_idx)
231
  if layer_idx == 0:
232
+ node_colors[node] = "#ff9f40" # Orange for input
233
  elif layer_idx == len(all_layers) - 1:
234
+ node_colors[node] = "#ff9f40" # Orange for output
235
  else:
236
+ node_colors[node] = "#40a0ff" # Blue for hidden
237
 
238
  for i in range(len(all_layers) - 1):
239
  for node1 in all_layers[i]:
240
  for node2 in all_layers[i + 1]:
241
+ G.add_edge(node1, node2, weight=np.random.uniform(0.1, 1.0)) # Random weight for thickness
242
 
243
  pos = nx.multipartite_layout(G, subset_key="layer", align="vertical")
244
  pos_rotated = {node: (-y, x) for node, (x, y) in pos.items()}
 
246
  pos_rotated[node] = (pos_rotated[node][0] * 2, pos_rotated[node][1] * 2)
247
 
248
  fig, ax = plt.subplots(figsize=(8, 4))
249
+ ax.set_facecolor("#f5f5f5")
250
+ edge_colors = [plt.cm.RdBu(G[u][v]['weight']) for u, v in G.edges()]
251
  nx.draw(
252
  G, pos_rotated,
253
  with_labels=True,
254
  node_color=[node_colors[node] for node in G.nodes()],
255
+ edge_color=edge_colors,
256
+ node_shape="s", # Square nodes
257
+ node_size=1200,
258
  font_size=8,
259
  font_color="black",
260
  font_weight="bold",
261
  edgecolors="black",
262
+ width=[G[u][v]['weight'] * 2 for u, v in G.edges()], # Vary thickness
263
+ arrows=False,
264
  ax=ax
265
  )
266
+ plt.title("Neural Network Structure", color="#333", fontsize=12, pad=10)
267
+
268
+ # Add + and - buttons for layers
269
+ col_plus, col_minus = st.columns([1, 1])
270
+ with col_plus:
271
+ st.button("+", key="add_layer", on_click=add_layer)
272
+ with col_minus:
273
+ st.button("−", key="remove_layer", on_click=remove_layer)
274
+
275
+ # Layer neuron controls
276
+ for i in range(st.session_state.num_hidden_layers):
277
+ col_dec, col_label, col_inc = st.columns([1, 2, 1])
278
+ with col_dec:
279
+ st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,))
280
+ with col_label:
281
+ st.write(f"{st.session_state.hidden_layer_neurons[i]} neurons")
282
+ with col_inc:
283
+ st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
284
+
285
  return fig
286
 
287
  st.pyplot(draw_nn(selected_features, st.session_state.hidden_layer_neurons))
 
304
  if st.session_state.hidden_layer_neurons[i] > 1:
305
  st.session_state.hidden_layer_neurons[i] -= 1
306
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  st.markdown('</div>', unsafe_allow_html=True)
308
 
309
+ # Right panel: Output
310
  with col_right:
311
  st.markdown('<div class="panel">', unsafe_allow_html=True)
312
+ st.subheader("OUTPUT")
313
+ if st.session_state.training:
314
+ train_loss = 0.505 # Simulated value
315
+ test_loss = 0.513 # Simulated value
316
+ st.write(f"Training loss: {train_loss:.3f}")
317
+ st.write(f"Test loss: {test_loss:.3f}")
318
+ col1, col2 = st.columns(2)
319
+ with col1:
320
+ plt.figure(figsize=(3, 3))
321
+ if problem_type == "Classification":
322
+ y_pred_proba = model.predict(selected_data, verbose=0)
323
+ y_pred = (y_pred_proba > 0.5).astype(int).ravel()
324
+ try:
325
+ plot_decision_regions(selected_data[:, :2], cv, clf=model, legend=2)
326
+ except Exception as e:
327
+ st.warning(f"Decision region plot failed: {e}")
328
+ xx, yy = np.meshgrid(np.linspace(selected_data[:, 0].min(), selected_data[:, 0].max(), 100),
329
+ np.linspace(selected_data[:, 1].min(), selected_data[:, 1].max(), 100))
330
+ grid = np.c_[xx.ravel(), yy.ravel()]
331
+ Z = model.predict(grid, verbose=0)
332
+ Z = (Z > 0.5).astype(int).reshape(xx.shape)
333
+ plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
334
+ plt.scatter(selected_data[:, 0], selected_data[:, 1], c=cv, cmap="coolwarm", edgecolors="k", alpha=0.7)
335
+ else:
336
+ y_pred = model.predict(selected_data, verbose=0)
337
+ plt.scatter(selected_data[:, 0], cv, c="blue", alpha=0.5)
338
+ plt.plot(selected_data[:, 0], y_pred, "r-", linewidth=2)
339
+ plt.gca().set_facecolor("#e6f3ff")
340
+ plt.xticks([])
341
+ plt.yticks([])
342
+ st.pyplot(plt)
343
+ with col2:
344
+ fig, ax = plt.subplots(figsize=(3, 3))
345
+ ax.plot([1, 2, 3], [0.5, 0.5, 0.5], "b-", label="Train") # Simulated loss
346
+ ax.plot([1, 2, 3], [0.51, 0.51, 0.51], "r--", label="Val") # Simulated loss
347
+ ax.legend()
348
+ ax.set_facecolor("#e6f3ff")
349
+ st.pyplot(fig)
350
+ st.write("Colors shows data, neuron and weight values.")
351
+ st.markdown('</div>', unsafe_allow_html=True)
352
 
353
+ # Training logic (moved outside for clarity)
354
+ if st.session_state.training:
355
  def create_model(input_dim, neurons):
356
  model = Sequential()
357
  model.add(Input(shape=(input_dim,)))
 
372
  self.placeholder = st.empty()
373
 
374
  def on_epoch_end(self, epoch, logs=None):
375
+ st.session_state.epoch = epoch + 1
376
  self.losses["Epoch"].append(epoch + 1)
377
  self.losses["Train Loss"].append(logs["loss"])
378
  self.losses["Val Loss"].append(logs["val_loss"])
 
384
  y_pred_proba = self.model.predict(self.X, verbose=0)
385
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
386
  try:
387
+ plot_decision_regions(self.X[:, :2], self.y, clf=self.model, legend=2)
388
  except Exception as e:
389
  st.warning(f"Decision region plot failed: {e}")
390
  xx, yy = np.meshgrid(np.linspace(self.X[:, 0].min(), self.X[:, 0].max(), 100),
 
398
  y_pred = self.model.predict(self.X, verbose=0)
399
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
400
  plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
401
+ plt.gca().set_facecolor("#e6f3ff")
402
  plt.xticks([])
403
  plt.yticks([])
404
  st.pyplot(plt)
 
407
  ax.plot(self.losses["Epoch"], self.losses["Train Loss"], "b-", label="Train")
408
  ax.plot(self.losses["Epoch"], self.losses["Val Loss"], "r--", label="Val")
409
  ax.legend()
410
+ ax.set_facecolor("#e6f3ff")
411
  st.pyplot(fig)
412
 
413
+ model = create_model(len(selected_features), st.session_state.hidden_layer_neurons)
414
+ callback = OutputCallback(selected_data, cv)
415
+ model.fit(selected_data, cv, epochs=999999, batch_size=batch_size, validation_split=1-train_ratio, callbacks=[callback], verbose=0)