|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
from tensorflow.keras.models import load_model |
|
|
from streamlit_drawable_canvas import st_canvas |
|
|
|
|
|
|
|
|
# Page-level settings; st.set_page_config must be the first Streamlit call.
_PAGE_CONFIG = dict(page_title="Digit Recognition App", page_icon="🔢", layout="wide")
st.set_page_config(**_PAGE_CONFIG)
|
|
|
|
|
|
|
|
# Inject a soft left-to-right gradient behind the whole app container.
_BACKGROUND_CSS = (
    "<style>\n"
    ".stApp {\n"
    "    background: linear-gradient(to right, #f8f9fa, #e3f2fd);\n"
    "}\n"
    "</style>\n"
)
st.markdown(_BACKGROUND_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_cnn_model(model_path="mnist_cnn.h5"):
    """Load the trained Keras CNN once and cache it across Streamlit reruns.

    Args:
        model_path: Path to the saved model. A relative path is looked up
            next to this script first, so the app works no matter which
            directory ``streamlit run`` is launched from. If no such file
            exists there, the path is used as given (relative to the CWD),
            which was the original behavior.

    Returns:
        The object returned by ``tensorflow.keras.models.load_model``.
    """
    from pathlib import Path

    beside_script = Path(__file__).resolve().parent / model_path
    return load_model(str(beside_script) if beside_script.is_file() else model_path)


model = load_cnn_model()
|
|
|
|
|
|
|
|
def preprocess_pil_file(file_or_pil_image):
    """Turn an uploaded file (or an open PIL image) into a model-ready batch.

    Steps: flatten any transparency onto white, convert to grayscale,
    resize to 28x28, scale to [0, 1], and invert when the image is mostly
    light, so the digit ends up light-on-dark.

    Args:
        file_or_pil_image: A ``PIL.Image.Image`` or anything accepted by
            ``PIL.Image.open`` (path or file-like object).

    Returns:
        ``(batch, preview)`` where ``batch`` is a float32 array of shape
        ``(1, 28, 28, 1)`` and ``preview`` is the 28x28 grayscale PIL image
        of the same pixels, for display.

    Raises:
        OSError: propagated from PIL when the file cannot be read as an image
            (``PIL.UnidentifiedImageError`` is a subclass).
    """
    if isinstance(file_or_pil_image, Image.Image):
        img = file_or_pil_image
    else:
        img = Image.open(file_or_pil_image)

    # convert('L') silently discards alpha. A transparent PNG whose hidden
    # pixels are black would therefore turn into an all-black image and lose
    # the digit. Composite onto white first, matching preprocess_canvas_image.
    if "A" in img.getbands() or "transparency" in img.info:
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba)

    img = img.convert("L").resize((28, 28))
    arr = np.asarray(img, dtype="float32") / 255.0

    # Mostly-light image => dark digit on light background; invert it.
    if arr.mean() > 0.5:
        arr = 1.0 - arr

    batch = arr.reshape(1, 28, 28, 1)
    preview = Image.fromarray((arr * 255).astype("uint8"))
    return batch, preview
|
|
|
|
|
def preprocess_canvas_image(image_data):
    """Convert drawable-canvas pixel data into a model-ready batch.

    Args:
        image_data: Array from ``st_canvas(...).image_data`` with shape
            (H, W, 4) RGBA or (H, W, 3) RGB. A 2-D (H, W) grayscale array
            is also accepted. ``None`` is allowed.

    Returns:
        ``(batch, preview)`` where ``batch`` is a float32 array of shape
        ``(1, 28, 28, 1)`` with values in [0, 1] (inverted when mostly light)
        and ``preview`` is the matching 28x28 PIL image.

        Returns ``(None, None)`` when there is nothing to classify: no data,
        or a uniform image such as an untouched canvas.
    """
    if image_data is None:
        return None, None

    pixels = image_data.astype('uint8')

    if pixels.ndim == 2:
        gray = pixels
    elif pixels.shape[2] == 4:
        # Composite the RGBA drawing over white so transparent areas read as
        # background rather than black.
        alpha = pixels[..., 3] / 255.0
        rgb = pixels[..., :3].astype('float32')
        white = np.ones_like(rgb) * 255.0
        comp = (rgb * alpha[..., None] + white * (1 - alpha[..., None])).astype('uint8')
        gray = cv2.cvtColor(comp, cv2.COLOR_RGB2GRAY)
    else:
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

    # A blank canvas is a single flat color. Without this guard it would be
    # inverted to all zeros and still produce a confident-looking prediction.
    if gray.min() == gray.max():
        return None, None

    small = cv2.resize(gray, (28, 28), interpolation=cv2.INTER_AREA).astype('float32') / 255.0

    # Mostly-light image => dark strokes on light background; invert it.
    if small.mean() > 0.5:
        small = 1.0 - small

    arr = small.reshape(1, 28, 28, 1).astype('float32')
    display_img = Image.fromarray((small * 255).astype('uint8'))
    return arr, display_img
|
|
|
|
|
|
|
|
# Centered title, a one-line usage hint, then a horizontal rule.
_TITLE_HTML = "<h1 style='text-align:center;color:#0D47A1;'>🔢 Handwritten Digit Recognizer</h1>"
_INTRO_TEXT = "Upload or draw a digit (0–9). The app will preprocess the image and predict the digit."

st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.write(_INTRO_TEXT)
st.markdown("---")
|
|
|
|
|
|
|
|
# Sidebar: usage instructions plus author/about links.
st.sidebar.header("📌 Instructions")
# The trailing two spaces before each "\n" force Markdown hard line breaks.
st.sidebar.info(
    "• Upload PNG/JPG or draw a digit.  \n"
    "• The app auto-preprocesses (grayscale, resize, normalize, invert if needed).  \n"
    "• Predictions show digit + confidence & probability bar chart."
)
st.sidebar.markdown("---")
# U+200D (zero-width joiner) fuses 👩 and 💻 into the single "woman
# technologist" emoji; without it they render as two separate emojis.
st.sidebar.write("👩\u200d💻 **About**: Built with Streamlit ❤️ by **Anam Jafar**")
st.sidebar.write("[🔗 LinkedIn](https://www.linkedin.com/in/anam-jafar6/)")
|
|
|
|
|
|
|
|
uploaded_files = st.file_uploader(
    "📂 Upload digit images (single or multiple):",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
)

if uploaded_files:
    st.subheader("📷 Uploaded Images & Predictions")

    # Preprocess every upload first so the model scores them all in one
    # batched predict() call instead of one call per image.
    processed = []  # list of (batch_array, preview_image) for readable files
    for file in uploaded_files:
        try:
            arr, display_img = preprocess_pil_file(file)
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # corrupt or mislabeled files; report it instead of crashing.
            st.error(f"Could not read '{file.name}' as an image.")
            continue
        processed.append((arr, display_img))

    if processed:
        batch = np.concatenate([arr for arr, _ in processed], axis=0)
        # verbose=0 keeps Keras' progress bar out of the server log.
        all_probs = model.predict(batch, verbose=0)

        # Lay results out in rows of at most max_cols columns.
        max_cols = 4
        for start in range(0, len(processed), max_cols):
            row = processed[start:start + max_cols]
            row_probs = all_probs[start:start + max_cols]
            cols = st.columns(len(row))
            for col, (_, display_img), probs in zip(cols, row, row_probs):
                label = int(np.argmax(probs))
                conf = float(np.max(probs))
                with col:
                    st.image(display_img, caption=f"Pred: {label} ({conf*100:.1f}%)", width=60)
                    st.bar_chart(probs)
|
|
|
|
|
|
|
|
st.subheader("🖌️ Draw your digit here:")
canvas_result = st_canvas(
    stroke_width=12,
    stroke_color="#000000",
    background_color="#FFFFFF",
    width=280,
    height=280,
    drawing_mode="freedraw",
    key="canvas",
)

if canvas_result is not None and canvas_result.image_data is not None:
    arr, display_img = preprocess_canvas_image(canvas_result.image_data)
    if arr is not None:
        # Each stroke reruns the script; verbose=0 avoids a Keras progress
        # bar in the server log on every rerun.
        probs = model.predict(arr, verbose=0)[0]
        label = int(np.argmax(probs))
        conf = float(np.max(probs))

        # Build the card as one unindented HTML string. In Markdown, lines
        # indented 4+ spaces after a blank line are parsed as a code block,
        # which would show raw tags instead of the rendered card.
        result_html = (
            '<div style="padding:12px;border-radius:8px;background:#FFF3CD;text-align:center;">'
            f'<h2 style="color:#D32F2F;">🎯 Predicted Digit: {label}</h2>'
            f"<p>Confidence: {conf*100:.2f}%</p>"
            "</div>"
        )
        st.markdown(result_html, unsafe_allow_html=True)
        st.image(display_img, caption="Preprocessed (28×28) view", width=60)
        st.bar_chart(probs)
    else:
        # preprocess_canvas_image returned no batch: nothing to classify.
        st.info("Draw a digit on the canvas above to get a prediction.")
|
|
|
|
|
|
|
|
# Horizontal rule followed by a centered credits line.
_FOOTER_HTML = "<p style='text-align:center;'>Built with ❤️ using Streamlit & TensorFlow | By <b>Anam Jafar</b></p>"

st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)
|
|
|