trohith89 committed on
Commit
8a4c831
·
verified ·
1 Parent(s): 3769fa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +374 -140
app.py CHANGED
@@ -1,156 +1,390 @@
1
  import streamlit as st
 
 
 
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
5
- import graphviz
6
- import time
7
- import tensorflow as tf
8
- from tensorflow import keras
9
- from tensorflow.keras import layers, Sequential
10
- from tensorflow.keras.regularizers import l2
11
  from sklearn.model_selection import train_test_split
12
- from sklearn.datasets import make_moons, make_circles, make_classification, make_blobs
 
 
 
 
13
  from mlxtend.plotting import plot_decision_regions
14
- from sklearn.base import BaseEstimator, ClassifierMixin
15
 
16
- # Set Streamlit page title
17
- st.set_page_config(page_title="Neural Network Trainer", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # ================= Session State for Training Controls =================
20
- if "epoch" not in st.session_state:
21
- st.session_state.epoch = 0
22
- if "running" not in st.session_state:
23
- st.session_state.running = False
24
- if "train_loss_history" not in st.session_state:
25
- st.session_state.train_loss_history = []
26
- if "test_loss_history" not in st.session_state:
27
- st.session_state.test_loss_history = []
 
 
28
 
29
- # ================= TRAINING CONTROL PANEL (Top) =================
30
- st.markdown("### Training Controls")
31
- col1, col2, col3, col4, col5, col6, col7, col8, col9 = st.columns(9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  with col1:
34
- if st.button("↩️ Reset"):
35
- st.session_state.epoch = 0
36
- st.session_state.running = False
37
- st.session_state.train_loss_history = []
38
- st.session_state.test_loss_history = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  with col2:
40
- if st.button("▶️ Train"):
41
- st.session_state.running = True
 
 
 
 
 
 
 
 
 
42
  with col3:
43
- if st.button("⏸️ Pause"):
44
- st.session_state.running = False
45
- with col4:
46
- activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
47
- with col5:
48
- problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
49
- with col6:
50
- learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1])
51
- with col7:
52
- num_epochs = st.slider("Epochs", 1, 100, 10)
53
- with col8:
54
- batch_size = st.slider("Batch Size", 1, 100, 10)
55
- with col9:
56
- dataset_option = st.selectbox("Select Dataset", ["Moons", "Circles", "Blobs", "Classification"])
57
- noise = st.slider("Noise Level", 0.0, 0.5, 0.2)
58
-
59
- # ================= MAIN LAYOUT =================
60
- col_features, col_hidden, col_plot = st.columns([2, 2, 3])
61
-
62
- # ========== FEATURES PANEL (Left) ========== #
63
- with col_features:
64
- st.header("FEATURES")
65
- st.write("Select input features:")
66
- x1 = st.checkbox("X₁", value=True)
67
- x2 = st.checkbox("X₂", value=True)
68
-
69
- # ========== HIDDEN LAYERS PANEL (Middle) ========== #
70
- with col_hidden:
71
- st.header("HIDDEN LAYERS")
72
- hidden_layers = st.slider("Number of Hidden Layers", 1, 7, 2)
73
- neurons = [st.slider(f"Neurons in Layer {i+1}", 1, 20, 4) for i in range(hidden_layers)]
74
-
75
- # ================= DATASET GENERATION =================
76
- def generate_data():
77
- if dataset_option == "Moons":
78
- X, y = make_moons(n_samples=1000, noise=noise)
79
- elif dataset_option == "Circles":
80
- X, y = make_circles(n_samples=1000, noise=noise)
81
- elif dataset_option == "Blobs":
82
- X, y = make_blobs(n_samples=1000, centers=2, cluster_std=noise)
83
- else:
84
- X, y = make_classification(n_samples=1000, n_features=2, n_classes=2, n_clusters_per_class=1, n_redundant=0, flip_y=noise)
85
- return X, y
86
-
87
- X, y = generate_data()
88
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
89
-
90
- # ========== TRAINING ANN ========== #
91
- def build_ann():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  model = Sequential()
93
- model.add(layers.Input(shape=(X.shape[1],)))
94
- for units in neurons:
95
- model.add(layers.Dense(units=units, activation='tanh'))
96
- model.add(layers.Dense(units=1, activation='sigmoid', kernel_regularizer=l2(0.1)))
97
- model.compile(optimizer=keras.optimizers.Adam(learning_rate), loss="binary_crossentropy" if problem_type == "Classification" else "mse")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  return model
99
 
100
- class KerasClassifierWrapper(BaseEstimator, ClassifierMixin):
101
- def __init__(self, model):
102
- self.model = model
103
- def fit(self, X, y):
104
- self.model.fit(X, y, epochs=num_epochs, batch_size=batch_size, verbose=0)
105
- return self
106
- def predict(self, X):
107
- return (self.model.predict(X) > 0.5).astype(int).flatten()
108
-
109
- wrapper_model = None
110
- if st.session_state.running:
111
- model = build_ann()
112
- history = model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size, validation_data=(X_test, y_test), verbose=0)
113
- st.session_state.train_loss_history = history.history["loss"]
114
- st.session_state.test_loss_history = history.history["val_loss"]
115
- wrapper_model = KerasClassifierWrapper(model)
116
-
117
- # ========== LOSS PLOT ========== #
118
- with col_plot:
119
- st.header("Loss Plot")
120
- fig, ax = plt.subplots()
121
- ax.plot(range(1, len(st.session_state.train_loss_history) + 1), st.session_state.train_loss_history, marker="o", label="Train Loss")
122
- ax.plot(range(1, len(st.session_state.test_loss_history) + 1), st.session_state.test_loss_history, marker="s", label="Test Loss")
123
- ax.set_title("Epoch vs. Loss")
124
- ax.set_xlabel("Epoch")
125
- ax.set_ylabel("Loss")
126
- ax.legend()
127
- st.pyplot(fig)
128
-
129
- # =================== DECISION REGION =================== #
130
- if problem_type == "Classification":
131
- fig, ax = plt.subplots()
132
- plot_decision_regions(X_train, y_train, clf=wrapper_model, ax=ax)
133
- ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor='k')
134
- ax.set_title("Decision Region")
135
- st.pyplot(fig)
136
-
137
- # =================== DRAW NEURAL NETWORK =================== #
138
- def draw_neural_network():
139
- graph = graphviz.Digraph()
140
- graph.node("X1", "X₁", shape="circle", style="filled", fillcolor="lightblue")
141
- graph.node("X2", "X₂", shape="circle", style="filled", fillcolor="lightblue")
142
- prev_layer = ["X1", "X2"]
143
- for i, num_neurons in enumerate(neurons):
144
- current_layer = [f"H{i+1}{j+1}" for j in range(num_neurons)]
145
- for neuron in current_layer:
146
- graph.node(neuron, neuron, shape="circle", style="filled", fillcolor="lightyellow")
147
- for prev in prev_layer:
148
- for curr in current_layer:
149
- graph.edge(prev, curr)
150
- prev_layer = current_layer
151
- graph.node("Output", "Output", shape="circle", style="filled", fillcolor="lightgreen")
152
- for neuron in prev_layer:
153
- graph.edge(neuron, "Output")
154
- return graph
155
-
156
- st.graphviz_chart(draw_neural_network())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import networkx as nx
3
+
4
+ import pandas as pd
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
  import seaborn as sns
8
+
9
+ from IPython.display import clear_output
10
+ import io
11
+
12
+ import sklearn
 
13
  from sklearn.model_selection import train_test_split
14
+ from sklearn.metrics import log_loss
15
+ from sklearn.datasets import make_classification, make_circles
16
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
17
+
18
+ import mlxtend
19
  from mlxtend.plotting import plot_decision_regions
 
20
 
21
+ import keras
22
+ import tensorflow as tf
23
+ from keras.optimizers import SGD
24
+ from keras.models import Sequential
25
+ from keras.layers import Input, Dense
26
+ from keras.losses import BinaryCrossentropy
27
+ from keras.regularizers import l2, l1
28
+ from keras.callbacks import Callback
29
+
30
+ st.set_page_config(layout='wide')
31
+
32
+ # Session state for tracking training process
33
+ if "training" not in st.session_state:
34
+ st.session_state.training = False
35
+ if "num_hidden_layers" not in st.session_state:
36
+ st.session_state.num_hidden_layers = 0
37
+ if "hidden_layer_neurons" not in st.session_state:
38
+ st.session_state.hidden_layer_neurons = []
39
+ if "prev_params" not in st.session_state:
40
+ st.session_state.prev_params = {}
41
+
42
def reset_session():
    # Wipe every key in Streamlit session state so the app restarts from defaults
    # (caller triggers st.rerun() immediately afterwards).
    st.session_state.clear()
44
+
45
+
46
+ st.title("Neural Network Playground")
47
+
48
+ # Sidebar for parameters
49
+ st.sidebar.title("Configure & Train Model")
50
+ problem_type = st.sidebar.selectbox("Problem Type", ["Classification", "Regression"])
51
+ dataset_type = st.sidebar.selectbox("Select Dataset Type", ["Circle", "Gaussian"])
52
+ learning_rate = st.sidebar.selectbox("Learning Rate", [0.00001,0.0001,0.001,0.01,0.03,0.1,0.3,1,3,10])
53
+ regularization_type = st.sidebar.selectbox("Regularization", ["None", "L1", "L2"])
54
+ regularization_rate = st.sidebar.selectbox("Regularization Rate", [0.0,0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, 3, 10])
55
+ activation_function = st.sidebar.selectbox("Activation Function", ["ReLU", "Sigmoid", "Tanh"])
56
+ train_to_test_ratio = st.sidebar.slider("Train-to-Test Ratio (%)", 10, 90, 20, 10) / 100
57
+ noise_level_slider = st.sidebar.slider("Noise Level", 0, 50, step=5)
58
+ batch_size = st.sidebar.slider("Batch Size", 1, 30, 10)
59
+
60
+ if st.sidebar.button("🔄 Reset Session"):
61
+ reset_session()
62
+ st.rerun()
63
+
64
+ # min noise
65
+ min_noise = 0.09
66
+
67
+ # Scale the slider value into the range [min_noise, 0.2] (i.e. [0.09, 0.2])
68
+ noise_level = min_noise + (noise_level_slider / 50) * (0.2 - min_noise)
69
 
70
+ # Store current parameter values in a dictionary
71
+ current_params = {
72
+ "dataset_type": dataset_type,
73
+ "learning_rate": learning_rate,
74
+ "regularization_type": regularization_type,
75
+ "regularization_rate": regularization_rate,
76
+ "activation_function": activation_function,
77
+ "train_to_test_ratio": train_to_test_ratio,
78
+ "batch_size": batch_size,
79
+ "noise_level": noise_level
80
+ }
81
 
82
+ # Derive flip_y from noise_level. NOTE(review): noise_level is already scaled to at most 0.2, so dividing by 50 again gives flip_y < 0.005 — confirm this second division is intended.
83
+ flip_y = noise_level / 50
84
+ class_sep = max(2.0 - 1.5 * flip_y, 0.5) # Decreases as noise increases
85
+ cluster_std = min(1.0 + 3.0 * flip_y, 3.0) # Increases as noise increases
86
+
87
+ # Generate dataset based on selection
88
+ if dataset_type == "Gaussian":
89
+ fv, cv = make_classification(
90
+ n_samples=800,
91
+ n_features=2,
92
+ n_informative=2,
93
+ n_redundant=0,
94
+ n_repeated=0,
95
+ n_classes=2,
96
+ class_sep=class_sep,
97
+ flip_y=flip_y,
98
+ n_clusters_per_class=1,
99
+ )
100
+ else:
101
+ fv, cv = make_circles(
102
+ n_samples=800,
103
+ shuffle=True,
104
+ noise=noise_level,
105
+ factor=0.2
106
+ )
107
+
108
+ # Functions for modifying hidden layers
109
def add_layer():
    """Append one hidden layer (capped at 6), starting it with a single neuron."""
    if st.session_state.num_hidden_layers >= 6:
        return
    st.session_state.num_hidden_layers += 1
    st.session_state.hidden_layer_neurons.append(1)
113
+
114
def remove_layer():
    """Drop the most recently added hidden layer, if any remain."""
    if st.session_state.num_hidden_layers <= 0:
        return
    st.session_state.num_hidden_layers -= 1
    st.session_state.hidden_layer_neurons.pop()
118
+
119
+ # Functions for modifying neurons in each layer
120
def increase_neurons(layer_idx):
    """Add one neuron to hidden layer `layer_idx` (maximum 8 per layer)."""
    current = st.session_state.hidden_layer_neurons[layer_idx]
    if current < 8:
        st.session_state.hidden_layer_neurons[layer_idx] = current + 1
123
+
124
def decrease_neurons(layer_idx):
    """Remove one neuron from hidden layer `layer_idx` (minimum 1 per layer)."""
    current = st.session_state.hidden_layer_neurons[layer_idx]
    if current > 1:
        st.session_state.hidden_layer_neurons[layer_idx] = current - 1
127
+
128
+ col1, col2, col3 = st.columns([2, 2, 2])
129
 
130
  with col1:
131
+ st.subheader("Select Input Features")
132
+
133
+ # Compute new features
134
+ std = StandardScaler()
135
+ X = std.fit_transform(fv)
136
+ x1, x2 = X[:, 0], X[:, 1]
137
+ x1_squared, x2_squared = x1**2, x2**2 # Squared features
138
+
139
+ # Update feature selection
140
+ available_features = ["X1", "X2", "X1^2", "X2^2"]
141
+
142
+ st.markdown("""
143
+ <style>
144
+ div[data-testid="stCheckbox"] {
145
+ background-color: #252830;
146
+ border-radius: 8px;
147
+ padding: 8px;
148
+ margin-bottom: 5px;
149
+ color: white;
150
+ }
151
+ div[data-testid="stCheckbox"] label {
152
+ font-size: 16px;
153
+ font-weight: bold;
154
+ color: white;
155
+ }
156
+ </style>
157
+ """, unsafe_allow_html=True)
158
+
159
+ selected_features = [feature for feature in available_features if st.checkbox(feature, value = st.session_state.get(feature, feature in ["X1", "X2"]), key=feature)]
160
+ num_inputs = len(selected_features)
161
+
162
+ # Map feature names to actual values
163
+ feature_mapping = {
164
+ "X1": x1,
165
+ "X2": x2,
166
+ "X1^2": x1_squared,
167
+ "X2^2": x2_squared
168
+ }
169
+
170
  with col2:
171
+ # Visualize dataset
172
+ st.subheader("Dataset Preview")
173
+ fig, ax = plt.subplots(figsize=(3, 3))
174
+ scatter = ax.scatter(fv[:, 0], fv[:, 1], c=cv, cmap="coolwarm", edgecolors="k", alpha=0.7)
175
+ ax.set_xticks([])
176
+ ax.set_yticks([])
177
+ ax.set_facecolor("#f0f0f0")
178
+
179
+ st.pyplot(fig) # Display scatter plot
180
+
181
+ num_outputs = 1
182
  with col3:
183
+ st.subheader("Hidden Layers")
184
+ col1, col2 = st.columns([1, 1])
185
+ with col1:
186
+ st.button(" Add Layer", on_click=add_layer)
187
+ with col2:
188
+ st.button(" Remove Layer", on_click=remove_layer)
189
+
190
+ st.write("**Adjust Neurons in Each Layer:**")
191
+ for i in range(st.session_state.num_hidden_layers):
192
+ col1, col2, col3 = st.columns([1, 2, 1])
193
+ with col1:
194
+ st.button("➖", key=f"dec_neuron_{i}", on_click=decrease_neurons, args=(i,))
195
+ with col2:
196
+ st.markdown(f"**Layer {i+1}: {st.session_state.hidden_layer_neurons[i]} neurons**")
197
+ with col3:
198
+ st.button("➕", key=f"inc_neuron_{i}", on_click=increase_neurons, args=(i,))
199
+
200
+ # Stack selected features for training
201
+ selected_data = np.column_stack([feature_mapping[feature] for feature in selected_features])
202
+
203
+
204
+ # Function to draw the neural network visually
205
def draw_nn(selected_features, hidden_layer_neurons, num_outputs):
    """Draw the current network topology as a layered graph.

    Args:
        selected_features: list of feature names used as input-node labels.
        hidden_layer_neurons: neurons per hidden layer, indexed by layer.
        num_outputs: NOTE(review): currently unused — the output layer is
            hard-coded to a single "y1" node; confirm whether this parameter
            should drive the output-node count.

    Returns:
        A matplotlib Figure (suitable for st.pyplot).

    NOTE(review): the hidden-layer count comes from
    st.session_state.num_hidden_layers rather than len(hidden_layer_neurons);
    the two are kept in sync by add_layer/remove_layer — verify.
    """
    G = nx.DiGraph()

    # Define layers dynamically
    input_layer = selected_features  # Match node names with feature names
    hidden_layers = []
    if st.session_state.num_hidden_layers > 0:
        hidden_layers = [[f"hl{i+1}_{j+1}" for j in range(hidden_layer_neurons[i])] for i in range(st.session_state.num_hidden_layers)]
    output_layer = ["y1"]  # Single output neuron

    layers = [input_layer] + hidden_layers + [output_layer]

    # Node colors distinguish input / hidden / output layers.
    node_colors = {}
    input_color = "lightgreen"
    hidden_color = "lightblue"
    output_color = "salmon"

    # Add nodes; the "layer" attribute feeds multipartite_layout below.
    for layer_idx, layer in enumerate(layers):
        for node in layer:
            G.add_node(node, layer=layer_idx, edgecolors='black')
            if layer_idx == 0:
                node_colors[node] = input_color  # Input layer
            elif layer_idx == len(layers) - 1:
                node_colors[node] = output_color  # Output layer
            else:
                node_colors[node] = hidden_color  # Hidden layers

    # Add edges (fully connected between adjacent layers)
    for i in range(len(layers) - 1):
        for node1 in layers[i]:
            for node2 in layers[i + 1]:
                G.add_edge(node1, node2)

    # Graph layout: one column of nodes per "layer" attribute value.
    pos = nx.multipartite_layout(G, subset_key="layer")
    fig, ax = plt.subplots(figsize=(12, 4))

    # Style updates for a TensorFlow-Playground look: transparent figure,
    # dark axes background.
    fig.patch.set_alpha(0)
    ax.set_facecolor("#252830")
    ax.patch.set_alpha(1)

    # Color list must follow G.nodes iteration order.
    color_list = [node_colors[node] for node in G.nodes]

    nx.draw(G, pos, with_labels=True, node_color=color_list, edge_color="white", edgecolors="black",
            node_size=800, font_size=7.5, ax=ax, width=0.4, font_color="black", font_weight="bold")

    return fig
259
+
260
def create_ann_model(input_dim, hidden_layers, neurons_per_layer):
    """Build and compile a binary-classification MLP.

    Args:
        input_dim: number of selected input features.
        hidden_layers: NOTE(review): unused — the architecture is derived
            solely from len(neurons_per_layer); confirm the two can never
            disagree, or drop the parameter at the call site.
        neurons_per_layer: neuron count for each hidden layer, in order.

    Reads the module-level sidebar values `regularization_type`,
    `regularization_rate`, `activation_function`, and `learning_rate`
    at call time.

    Returns:
        A compiled keras Sequential model (sigmoid output, BCE loss, SGD).
    """
    model = Sequential()
    model.add(Input(shape=(input_dim,)))  # Input layer

    # Optional weight penalty selected in the sidebar ("None" leaves reg=None).
    reg = None
    if regularization_type == "L1":
        reg = l1(regularization_rate)
    elif regularization_type == "L2":
        reg = l2(regularization_rate)

    # Add hidden layers; .lower() maps the UI labels to keras activation
    # names ("relu", "sigmoid", "tanh").
    for neurons in neurons_per_layer:
        model.add(Dense(neurons, activation=activation_function.lower(), kernel_regularizer=reg))

    # Output layer: single sigmoid unit for binary classification.
    model.add(Dense(1, activation='sigmoid'))

    # Compile the model with explicit learning rate
    optimizer = SGD(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss=BinaryCrossentropy(),
        metrics=['accuracy']
    )
    return model
285
 
286
def plot_decision_boundary(model, X, y):
    """Plot mlxtend decision regions for `model` on (X, y).

    Returns:
        The `matplotlib.pyplot` module itself, not a Figure — callers hand it
        to st.pyplot. NOTE(review): this relies on pyplot's implicit "current
        figure"; returning plt.gcf() would be safer. Verify before changing,
        since figures are never closed here and can accumulate across epochs.
    """
    plt.figure(figsize=(6, 4))
    plot_decision_regions(X, y, clf=model, legend=2)
    return plt
291
+
292
class LossPlotCallback(tf.keras.callbacks.Callback):
    """Keras callback that live-updates two Streamlit panels after each epoch:
    the current decision boundary (left) and train/validation loss curves
    (right), both drawn into a single st.empty() placeholder so the page
    does not grow as training proceeds.

    NOTE(review): on_epoch_end reads the module-level globals `ann_model`,
    `selected_data`, and `cv` instead of self.model / self.X / self.y —
    confirm this coupling is intended before reusing the class elsewhere.
    """

    def __init__(self, X, y, display_epochs=10):
        # display_epochs is currently unused (see commented assignment below).
        super().__init__()
        # One row per completed epoch.
        self.loss_df = pd.DataFrame(columns=["Epoch", "Train Loss", "Val Loss"])
        #self.display_epochs = display_epochs
        self.X = X
        self.y = y
        self.plot_placeholder = st.empty()  # SINGLE container to update dynamically

    def on_epoch_end(self, epoch, logs=None):
        # Append new train and validation loss values
        new_row = pd.DataFrame({
            "Epoch": [epoch + 1],
            "Train Loss": [logs['loss']],
            "Val Loss": [logs['val_loss']]
        })
        self.loss_df = pd.concat([self.loss_df, new_row], ignore_index=True)

        # Redraw both panels in place inside the shared placeholder.
        with self.plot_placeholder.container():
            col1, col2 = st.columns([1, 1])

            # Left column: decision surface for the model as of this epoch.
            with col1:
                st.write("### Decision Boundary")
                # NOTE(review): uses global `ann_model`, not self.model.
                fig1 = plot_decision_boundary(ann_model, selected_data, cv)
                st.pyplot(fig1, clear_figure=True)

            # Right column: loss curves accumulated so far.
            with col2:
                st.write("### Training vs Validation Loss")
                fig2, ax = plt.subplots(figsize=(6, 4), dpi=100)
                ax.plot(self.loss_df["Epoch"], self.loss_df["Train Loss"], marker='o', markersize=1, linestyle='-', color='b', label="Train Loss")

                # Plot validation loss only when fit() produced val_loss values.
                if "Val Loss" in self.loss_df.columns and self.loss_df["Val Loss"].notna().any():
                    ax.plot(self.loss_df["Epoch"], self.loss_df["Val Loss"], marker='s', markersize=1, linestyle='--', color='r', label="Val Loss")

                ax.set_xlabel("Epochs", fontsize=12, fontweight='bold')
                ax.set_ylabel("Loss", fontsize=12, fontweight='bold')

                ax.legend(fontsize=10)

                ax.grid(True, linestyle='--', alpha=0.6)
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)

                # One x-tick per completed epoch.
                ax.set_xticks(range(1, len(self.loss_df) + 1))
                st.pyplot(fig2, clear_figure=True)
342
+
343
+
344
+ if current_params != st.session_state.prev_params:
345
+ st.session_state.training = False # Stop training when a parameter changes
346
+ st.session_state.prev_params = current_params
347
+
348
+ # Start/Stop Buttons
349
+ col1, col2 = st.columns([1, 1])
350
+ with col1:
351
+ if st.button("▶️ Start Training"):
352
+ st.session_state.training = True
353
+ st.session_state.model_trained = False
354
+
355
+ with col2:
356
+ if st.button("⏹️ Stop Training"):
357
+ st.session_state.training = False
358
+
359
+ # Render the neural network visualization
360
+ st.write("### Logical Structure of the Neural Network")
361
+ st.pyplot(draw_nn(selected_features, st.session_state.hidden_layer_neurons, num_outputs))
362
+
363
+ # Train Model if Start is clicked
364
+ if st.session_state.training:
365
+ # Train the model and track loss in a DataFrame
366
+ ann_model = create_ann_model(
367
+ len(selected_features),
368
+ st.session_state.num_hidden_layers,
369
+ st.session_state.hidden_layer_neurons
370
+ )
371
+
372
+ st.session_state.model_trained = True
373
+
374
+ loss_plot_callback = LossPlotCallback(X=selected_data, y=cv)
375
+
376
+ # Capture model summary
377
+ model_summary = io.StringIO()
378
+ ann_model.summary(print_fn=lambda x: model_summary.write(x + "\n"))
379
+
380
+ # Display ANN model summary in Streamlit
381
+ st.subheader("Artificial Neural Network Model Summary")
382
+ st.code(model_summary.getvalue(), language="plaintext")
383
+
384
+ history = ann_model.fit(
385
+ selected_data, cv,
386
+ epochs=999999,
387
+ validation_split=1-train_to_test_ratio,
388
+ batch_size=batch_size,
389
+ callbacks=[loss_plot_callback],
390
+ )