# app.py — Streamlit dashboard comparing ANN and CNN MNIST digit classifiers.
# Last change: "Update app.py" by FurqanIshaq (commit 4514a2b, verified).
import streamlit as st
import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import plotly.graph_objects as go
# --- Load Models ---
@st.cache_resource
def load_models():
    """Load the pre-trained ANN and CNN Keras models.

    The ``.h5`` paths are relative, so they resolve against the process's
    current working directory. ``st.cache_resource`` keeps the loaded models
    around so Streamlit reruns do not reload them from disk.

    Returns:
        tuple: ``(ann_model, cnn_model)`` in that order.
    """
    ann_path, cnn_path = "ann_mnist_model.h5", "cnn_mnist_model.h5"
    return keras.models.load_model(ann_path), keras.models.load_model(cnn_path)
# --- App Configuration ---
# set_page_config must run before any other Streamlit command (a hard
# requirement in many Streamlit releases). The cached load_models() call
# shows a spinner by default, which counts as a Streamlit command, so the
# page is configured first and the models are loaded afterwards.
st.set_page_config(page_title="MNIST Comparison Dashboard", page_icon="🧠", layout="wide")
ann_model, cnn_model = load_models()
# Page labels shared by the sidebar radio and the page dispatch below, so the
# two can never disagree.
CLASSIFIER_PAGE = "🖼️ Digit Classifier"
DASHBOARD_PAGE = "📊 Model Comparison Dashboard"

# Architecture summaries displayed verbatim via st.code on the dashboard.
# NOTE(review): these are hand-written descriptions, not derived from the
# loaded models — keep them in sync with the training code.
ANN_ARCHITECTURE = """
Input Layer (28x28 → 784)
Dense (128 neurons, ReLU)
Dropout (0.2)
Dense (10 neurons, Softmax)
"""

CNN_ARCHITECTURE = """
Conv2D (32 filters, 3x3)
MaxPooling2D (2x2)
Conv2D (64 filters, 3x3)
MaxPooling2D (2x2)
Flatten
Dense (128 neurons, ReLU)
Dense (10 neurons, Softmax)
"""


# --- MENU 1: Digit Classifier ---
def _render_digit_classifier():
    """Render the page that classifies an uploaded digit with both models.

    The upload is converted to grayscale, resized to 28x28, scaled to [0, 1],
    and fed to the ANN as shape (1, 28, 28) and to the CNN as (1, 28, 28, 1).
    The page shows each model's argmax prediction and, in an expander, a bar
    chart of both models' per-digit outputs. Unreadable uploads produce an
    error message instead of a traceback.
    """
    st.title("🧠 MNIST Digit Classifier – ANN vs CNN")
    st.markdown("Upload a **28×28 grayscale image** to test both models side-by-side.")
    uploaded = st.file_uploader("Upload a handwritten digit image", type=["png", "jpg", "jpeg"])
    if uploaded is None:
        return

    try:
        image = Image.open(uploaded).convert("L").resize((28, 28))
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for data it
        # cannot identify, and plain OSError for truncated/corrupt image files.
        st.error("Could not read the uploaded file as an image.")
        return

    # NOTE(review): no color inversion is applied. If the models were trained on
    # standard MNIST (light digits on a dark background), dark-ink-on-white
    # uploads may be misclassified — confirm against the training preprocessing.
    img_array = np.array(image) / 255.0  # scale 0..255 pixel values to [0, 1]
    ann_input = img_array.reshape(1, 28, 28)
    cnn_input = img_array.reshape(1, 28, 28, 1)  # trailing channel axis for the CNN

    # Predict once per model; the same outputs feed both the metrics and the
    # chart. verbose=0 keeps Keras progress bars out of the server log.
    ann_probs = ann_model.predict(ann_input, verbose=0)[0]
    cnn_probs = cnn_model.predict(cnn_input, verbose=0)[0]
    ann_pred = int(np.argmax(ann_probs))
    cnn_pred = int(np.argmax(cnn_probs))

    col1, col2 = st.columns([1, 2])
    with col1:
        st.image(image, caption="Uploaded Digit", width=150)
    with col2:
        st.metric("ANN Prediction", ann_pred)
        st.metric("CNN Prediction", cnn_pred)

    with st.expander("See prediction probabilities"):
        digits = list(range(10))
        fig = go.Figure()
        fig.add_trace(go.Bar(x=digits, y=ann_probs, name="ANN"))
        fig.add_trace(go.Bar(x=digits, y=cnn_probs, name="CNN"))
        fig.update_layout(title="Prediction Probability Comparison", xaxis_title="Digit", yaxis_title="Probability")
        st.plotly_chart(fig, use_container_width=True)


# --- MENU 2: Model Comparison Dashboard ---
def _comparison_line_chart(title, y_title, epochs, series):
    """Build a Plotly figure with one lines+markers trace per ``(name, values)`` pair.

    Args:
        title: Figure title.
        y_title: Y-axis label (the x-axis is always "Epoch").
        epochs: X values shared by every trace.
        series: Iterable of ``(trace_name, y_values)`` pairs, plotted in order.

    Returns:
        go.Figure: The assembled figure.
    """
    fig = go.Figure()
    for name, values in series:
        fig.add_trace(go.Scatter(x=epochs, y=values, mode="lines+markers", name=name))
    fig.update_layout(title=title, xaxis_title="Epoch", yaxis_title=y_title)
    return fig


def _render_comparison_dashboard():
    """Render the static ANN-vs-CNN comparison page.

    Everything on this page is hard-coded. The headline metrics are literal
    strings, and the training curves are simulated values. Nothing is computed
    from the loaded models.
    """
    st.title("📊 ANN vs CNN – Performance Dashboard")
    st.markdown("Explore a comprehensive comparison of **Artificial Neural Network (ANN)** and **Convolutional Neural Network (CNN)** on the MNIST dataset.")
    col1, col2, col3 = st.columns(3)
    col1.metric("ANN Test Accuracy", "97.8%", "Baseline")
    col2.metric("CNN Test Accuracy", "99.2%", "+1.4%")
    col3.metric("Training Data Size", "60,000 images")

    st.markdown("---")
    st.subheader("📈 Training History (Simulated Comparison)")
    # Simulated example histories for visual purposes
    epochs = np.arange(1, 11)
    ann_acc = [0.89, 0.92, 0.94, 0.95, 0.96, 0.967, 0.972, 0.975, 0.978, 0.978]
    cnn_acc = [0.93, 0.96, 0.97, 0.977, 0.982, 0.987, 0.989, 0.991, 0.992, 0.992]
    ann_loss = [0.35, 0.28, 0.22, 0.19, 0.16, 0.14, 0.12, 0.10, 0.09, 0.09]
    cnn_loss = [0.27, 0.19, 0.14, 0.11, 0.09, 0.07, 0.06, 0.05, 0.04, 0.04]
    accuracy_fig = _comparison_line_chart(
        "Accuracy Comparison Over Epochs",
        "Accuracy",
        epochs,
        [("ANN Accuracy", ann_acc), ("CNN Accuracy", cnn_acc)],
    )
    st.plotly_chart(accuracy_fig, use_container_width=True)
    loss_fig = _comparison_line_chart(
        "Loss Comparison Over Epochs",
        "Loss",
        epochs,
        [("ANN Loss", ann_loss), ("CNN Loss", cnn_loss)],
    )
    st.plotly_chart(loss_fig, use_container_width=True)

    st.markdown("---")
    st.subheader("🧩 Model Architecture Insights")
    st.write("### ANN:")
    st.code(ANN_ARCHITECTURE)
    st.write("### CNN:")
    st.code(CNN_ARCHITECTURE)

    st.markdown("---")
    # NOTE(review): the 💡 emoji here was restored from mis-encoded text; confirm it
    # matches the intended glyph.
    st.subheader("💡 Summary of Findings")
    st.success("✅ CNN achieves higher accuracy and faster convergence due to spatial feature extraction from convolution layers.")
    st.info("💡 ANN performs well but lacks the spatial awareness that CNNs provide, making CNNs ideal for image-based tasks like MNIST.")


# Sidebar navigation
st.sidebar.title("🧭 Navigation")
menu = st.sidebar.radio("Go to:", [CLASSIFIER_PAGE, DASHBOARD_PAGE])

if menu == CLASSIFIER_PAGE:
    _render_digit_classifier()
elif menu == DASHBOARD_PAGE:
    _render_comparison_dashboard()