# Captured from a web file viewer: author anamjafar6, commit b856d6d ("Update app.py"), 6.37 kB.
# -------------------------
# Handwritten Digit Recognition App (robust preprocessing)
# Built by Anam Jafar
# -------------------------
import streamlit as st
import numpy as np
import cv2
from PIL import Image
from tensorflow.keras.models import load_model
from streamlit_drawable_canvas import st_canvas
# Page config
st.set_page_config(
    page_title="Digit Recognition App",
    page_icon="πŸ”’",
    layout="wide",
)

# Light left-to-right gradient behind the whole app, injected as raw
# HTML/CSS through st.markdown (hence unsafe_allow_html).
_APP_BACKGROUND_CSS = (
    "\n"
    "<style>\n"
    ".stApp {\n"
    "background: linear-gradient(to right, #f8f9fa, #e3f2fd);\n"
    "}\n"
    "</style>\n"
)
st.markdown(_APP_BACKGROUND_CSS, unsafe_allow_html=True)
# Load model (cached)
MODEL_FILENAME = "mnist_cnn.h5"


@st.cache_resource
def load_cnn_model(model_path=None):
    """Load the trained Keras digit classifier.

    Wrapped in ``st.cache_resource`` so script reruns reuse the loaded model
    instead of reading it from disk again.

    Args:
        model_path: optional explicit path to the model file. When omitted,
            ``MODEL_FILENAME`` next to this script is used if it exists,
            otherwise ``MODEL_FILENAME`` relative to the working directory
            (the original lookup).

    Returns:
        The model object returned by ``load_model``.
    """
    from pathlib import Path

    if model_path is None:
        # Resolve next to the script so launching `streamlit run` from a
        # different directory still finds the model file.
        beside_script = Path(__file__).resolve().with_name(MODEL_FILENAME)
        model_path = beside_script if beside_script.is_file() else MODEL_FILENAME
    return load_model(str(model_path))


try:
    model = load_cnn_model()
except (OSError, ValueError) as exc:
    # Missing/corrupt model: show a readable message instead of a raw
    # traceback, and halt this script run since nothing can be predicted.
    st.error(f"Could not load the model file '{MODEL_FILENAME}': {exc}")
    st.stop()
# ---------------------
# Helper: preprocess PIL file uploads
# ---------------------
def preprocess_pil_file(file_or_pil_image):
    """Turn an uploaded image into a model-ready batch.

    Args:
        file_or_pil_image: a file-like object (e.g. from ``st.file_uploader``)
            or an already opened ``PIL.Image.Image``.

    Returns:
        ``(arr, display_img)`` where ``arr`` is a float32 array of shape
        ``(1, 28, 28, 1)`` with values in [0, 1], inverted if needed so the
        digit is bright on a dark background, and ``display_img`` is the
        28x28 grayscale PIL image of that array.

    Raises:
        Whatever ``PIL.Image.open`` raises for unreadable input (e.g.
        ``PIL.UnidentifiedImageError``).
    """
    from PIL import ImageOps

    if isinstance(file_or_pil_image, Image.Image):
        img = file_or_pil_image
    else:
        img = Image.open(file_or_pil_image)

    # Photos from phones often store rotation only in EXIF; apply it so the
    # digit is upright before it is shrunk to 28x28.
    img = ImageOps.exif_transpose(img)

    # convert('L') simply drops alpha: transparent pixels keep their stored
    # RGB (frequently black), so a dark digit on a transparent PNG would turn
    # into an all-black image. Flatten onto white first, the same way
    # preprocess_canvas_image composites the canvas alpha over white.
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba)

    img = img.convert("L").resize((28, 28))
    arr = np.asarray(img, dtype="float32") / 255.0  # 0..1

    # Auto-invert when the background is light (mean > 0.5), since the digit
    # is expected bright on a dark background.
    if arr.mean() > 0.5:
        arr = 1.0 - arr

    arr = arr.reshape(1, 28, 28, 1)
    display_img = Image.fromarray((arr[0, :, :, 0] * 255).astype("uint8"))
    return arr, display_img
# ---------------------
# Helper: preprocess canvas image (RGBA or RGB)
# ---------------------
def preprocess_canvas_image(image_data):
    """Convert a drawing-canvas frame into a model-ready batch.

    Args:
        image_data: numpy array from ``st_canvas``, either HxWx4 (RGBA) or
            HxWx3 (RGB); may be ``None``.

    Returns:
        ``(arr, display_img)``: a float32 array of shape ``(1, 28, 28, 1)``
        scaled to [0, 1] (inverted when the image is mostly light) plus a
        28x28 grayscale PIL preview of it, or ``(None, None)`` when
        ``image_data`` is ``None``.
    """
    if image_data is None:
        return None, None

    pixels = image_data.astype('uint8')

    if pixels.shape[2] != 4:
        rgb_image = pixels
    else:
        # Composite the RGBA drawing over a white backdrop using its alpha.
        coverage = pixels[..., 3] / 255.0
        channels = pixels[..., :3].astype('float32')
        backdrop = np.full_like(channels, 255.0)
        rgb_image = (
            channels * coverage[..., None]
            + backdrop * (1 - coverage[..., None])
        ).astype('uint8')

    gray = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)

    # INTER_AREA averages source pixels, which suits heavy downscaling.
    shrunk = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA)
    normalized = shrunk.astype('float32') / 255.0

    # Light background -> flip so strokes are bright on dark.
    if normalized.mean() > 0.5:
        normalized = 1.0 - normalized

    batch = normalized.reshape(1, 28, 28, 1).astype('float32')
    preview = Image.fromarray((normalized * 255).astype('uint8'))
    return batch, preview
# ---------------------
# UI: header & sidebar
# ---------------------
# Page title, short intro and a divider.
st.markdown(
    "<h1 style='text-align:center;color:#0D47A1;'>πŸ”’ Handwritten Digit Recognizer</h1>",
    unsafe_allow_html=True,
)
st.write("Upload or draw a digit (0–9). The app will preprocess the image and predict the digit.")
st.markdown("---")

# Sidebar: usage instructions followed by author credits.
_sidebar = st.sidebar
_sidebar.header("πŸ“Œ Instructions")
_instruction_lines = [
    "β€’ Upload PNG/JPG or draw a digit. \n",
    "β€’ The app auto-preprocesses (grayscale, resize, normalize, invert if needed). \n",
    "β€’ Predictions show digit + confidence & probability bar chart.",
]
_sidebar.info("".join(_instruction_lines))
_sidebar.markdown("---")
_sidebar.write("πŸ‘©β€πŸ’» **About**: Built with ❀️ by **Anam Jafar**")
_sidebar.write("[πŸ”— LinkedIn](https://www.linkedin.com/in/anam-jafar)")
# ---------------------
# FILE UPLOAD (multiple)
# ---------------------
# Multi-file uploader; each image is preprocessed, classified and shown in a
# grid of up to `max_cols` columns per row.
uploaded_files = st.file_uploader(
    "πŸ“‚ Upload digit images (single or multiple):",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)
if uploaded_files:
    st.subheader("πŸ“· Uploaded Images & Predictions")
    max_cols = 4
    for start in range(0, len(uploaded_files), max_cols):
        row_files = uploaded_files[start:start + max_cols]
        # Always allocate max_cols columns so a short final row keeps the
        # same image width instead of stretching its images across the page.
        cols = st.columns(max_cols)
        for col, file in zip(cols, row_files):
            with col:
                try:
                    arr, display_img = preprocess_pil_file(file)
                except (OSError, ValueError, Image.DecompressionBombError) as exc:
                    # One bad upload should not take down the whole page.
                    st.error(f"Could not read {file.name}: {exc}")
                    continue

                with st.spinner("Predicting..."):
                    probs = model.predict(arr)[0]
                label = int(np.argmax(probs))
                conf = float(probs[label])

                st.image(display_img, caption=f"Pred: {label} ({conf*100:.1f}%)", use_column_width=True)
                st.bar_chart(probs)  # probability distribution over the classes
                # Diagnostics stay available but collapsed and scoped to
                # this image's column rather than shown to every user.
                with st.expander("Debug info"):
                    st.write({"shape": arr.shape, "min": float(arr.min()), "max": float(arr.max())})
# ---------------------
# DRAWING PAD
# ---------------------
# Free-draw canvas: black strokes on a white 280x280 pad, classified live.
st.subheader("πŸ–ŒοΈ Draw your digit here:")
canvas_result = st_canvas(
    stroke_width=12,
    stroke_color="#000000",
    background_color="#FFFFFF",
    width=280,
    height=280,
    drawing_mode="freedraw",
    key="canvas",
)
if canvas_result is not None and canvas_result.image_data is not None:
    arr, display_img = preprocess_canvas_image(canvas_result.image_data)
    if arr is None or float(arr.max()) <= 0.0:
        # An untouched canvas preprocesses to an all-zero image; predicting
        # on it would just report an arbitrary digit with made-up confidence.
        st.info("Draw a digit on the canvas to get a prediction.")
    else:
        with st.spinner("Predicting..."):
            probs = model.predict(arr)[0]
        label = int(np.argmax(probs))
        conf = float(probs[label])
        # HTML lines are kept unindented: leading spaces in st.markdown
        # content can be rendered as a code block instead of HTML.
        st.markdown(
            f"""
<div style="padding:12px;border-radius:8px;background:#FFF3CD;text-align:center;">
<h2 style="color:#D32F2F;">🎯 Predicted Digit: {label}</h2>
<p>Confidence: {conf*100:.2f}%</p>
</div>
""",
            unsafe_allow_html=True,
        )
        st.image(display_img, caption="Preprocessed (28Γ—28) view", width=120)
        st.bar_chart(probs)
        # Diagnostics are collapsed instead of always shown to users.
        with st.expander("Debug info"):
            st.write({"canvas_shape": arr.shape, "min": float(arr.min()), "max": float(arr.max())})
# ---------------------
# Footer
# ---------------------
# Closing divider and a centered credit line rendered as raw HTML.
_FOOTER_HTML = (
    "<p style='text-align:center;'>Built with ❀️ using Streamlit & TensorFlow"
    " | By <b>Anam Jafar</b></p>"
)
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)