ANN / app.py
santosh7's picture
Update app.py
1ad540f verified
Raw
History Blame Contribute Delete
2.95 kB
# app.py
import streamlit as st
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt
from io import BytesIO
st.title("TensorFlow Neural Network Playground")

# Sidebar controls that define the network shape and optimizer settings.
sidebar = st.sidebar
sidebar.header("Network Configuration")
input_nodes = sidebar.slider("Input Layer Nodes", min_value=2, max_value=10, value=4)
hidden_nodes = sidebar.slider("Hidden Layer Nodes", min_value=2, max_value=10, value=6)
output_nodes = sidebar.slider("Output Layer Nodes", min_value=1, max_value=5, value=2)
learning_rate = sidebar.slider("Learning Rate", min_value=0.01, max_value=1.0, value=0.1)
def create_model(input_dim, hidden_dim, output_dim, lr):
    """Build and compile a one-hidden-layer dense classifier.

    Args:
        input_dim: Number of input features.
        hidden_dim: Number of units in the ReLU hidden layer.
        output_dim: Number of output units. For more than one unit the output
            is a softmax trained with categorical cross-entropy. For exactly
            one unit a softmax would always emit 1.0, so a sigmoid output with
            binary cross-entropy is used instead.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``tf.keras.Sequential`` model tracking ``accuracy``.
    """
    if output_dim == 1:
        # Softmax over a single logit is identically 1.0, so the loss would be
        # constant and nothing could be learned; use a binary setup instead.
        output_activation, loss = 'sigmoid', 'binary_crossentropy'
    else:
        output_activation, loss = 'softmax', 'categorical_crossentropy'

    model = tf.keras.Sequential([
        # Explicit Input instead of the `input_shape=` layer kwarg, which
        # newer Keras versions deprecate.
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(hidden_dim, activation='relu', name='Hidden_Layer'),
        tf.keras.layers.Dense(output_dim, activation=output_activation, name='Output_Layer'),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss=loss,
                  metrics=['accuracy'])
    return model
def plot_network(model):
    """Render *model*'s architecture diagram on the Streamlit page.

    ``plot_model`` takes a filename for ``to_file`` (the extension selects the
    image format), so an in-memory ``BytesIO`` cannot be used. The diagram is
    written to a temporary PNG and its bytes are passed to ``st.image``.
    ``plot_model`` depends on pydot and Graphviz. If the diagram cannot be
    produced, a warning is shown instead of crashing the app.

    Args:
        model: A built Keras model.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        img_path = os.path.join(tmp_dir, "model.png")
        try:
            plot_model(model,
                       to_file=img_path,
                       show_shapes=True,
                       show_layer_names=True,
                       rankdir='TB',
                       expand_nested=True,
                       dpi=96)
        except ImportError as err:
            # Raised by plot_model when pydot/Graphviz are not available.
            st.warning(f"Could not draw the network diagram: {err}")
            return
        if not os.path.exists(img_path):
            # Defensive: plot_model may return without writing a file in some
            # environments when its dependencies are missing.
            st.warning("Could not draw the network diagram (is pydot/Graphviz installed?).")
            return
        with open(img_path, "rb") as img_file:
            img_bytes = img_file.read()
    st.image(img_bytes, caption="Neural Network Architecture")
# Build the model from the current sidebar values and draw its architecture.
model = create_model(
    input_dim=input_nodes,
    hidden_dim=hidden_nodes,
    output_dim=output_nodes,
    lr=learning_rate,
)
st.subheader("Network Architecture")
plot_network(model)
def generate_sample_data(samples=100, input_dim=None, output_dim=None):
    """Return random features and one-hot labels for a quick demo fit.

    Args:
        samples: Number of rows to generate.
        input_dim: Number of feature columns. Defaults to the sidebar's
            ``input_nodes`` value.
        output_dim: Number of classes. Defaults to the sidebar's
            ``output_nodes`` value.

    Returns:
        ``(X, y)``:
        - ``X`` is a ``(samples, input_dim)`` float array drawn uniformly
          from [0, 1).
        - ``y`` is a ``(samples, output_dim)`` float32 one-hot matrix of
          uniformly random class labels.
    """
    if input_dim is None:
        input_dim = input_nodes
    if output_dim is None:
        output_dim = output_nodes
    X = np.random.random((samples, input_dim))
    labels = np.random.randint(0, output_dim, size=(samples,))
    # Row-indexing the identity matrix yields the same float32 one-hot matrix
    # as tf.keras.utils.to_categorical(labels, output_dim).
    y = np.eye(output_dim, dtype="float32")[labels]
    return X, y
if st.button("Train Model"):
    X, y = generate_sample_data()
    history = model.fit(X, y, epochs=10, verbose=0)
    st.write("Training Complete!")

    curves = history.history
    # Label the x-axis from epoch 1 rather than 0.
    epoch_axis = range(1, len(curves['loss']) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(epoch_axis, curves['accuracy'])
    ax1.set_title('Model Accuracy')
    ax1.set_ylabel('Accuracy')
    ax1.set_xlabel('Epoch')
    ax2.plot(epoch_axis, curves['loss'])
    ax2.set_title('Model Loss')
    ax2.set_ylabel('Loss')
    ax2.set_xlabel('Epoch')
    st.pyplot(fig)
    # plt.subplots registers the figure with pyplot until it is closed.
    # Release it so repeated clicks don't pile up figures in the server process.
    plt.close(fig)
def _model_summary_markdown(keras_model):
    """Return *keras_model*'s layers as one Markdown table string.

    The table is built from the layer objects instead of parsing the text of
    ``model.summary()``. That text layout differs between Keras versions, and
    the old whitespace split picked the wrong columns.
    """
    rows = [
        "| Layer (type) | Output Shape | Param # |",
        "|---|---|---:|",
    ]
    for layer in keras_model.layers:
        try:
            output_shape = tuple(layer.output.shape)
        except (AttributeError, ValueError, RuntimeError):
            # Output tensor unavailable (e.g. layer not connected/built yet).
            output_shape = "?"
        rows.append(
            f"| {layer.name} ({type(layer).__name__}) "
            f"| {output_shape} | {layer.count_params():,} |"
        )
    # A Markdown table must be sent in a single call to render as a table.
    return "\n".join(rows)


if st.checkbox("Show Model Summary"):
    st.subheader("Model Summary")
    st.markdown(f"### Model: {model.name}")
    st.markdown(_model_summary_markdown(model))
    st.markdown(f"**Total params:** {model.count_params():,}")