trohith89 commited on
Commit
ef99ba3
·
verified ·
1 Parent(s): f526557

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -1
app.py CHANGED
@@ -1,8 +1,31 @@
1
  import pandas as pd
2
- from PIL import Image
 
 
3
  import streamlit as st
4
  from streamlit_drawable_canvas import st_canvas
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  st.subheader("Draw Your Digit")
7
  canvas_result = st_canvas(
8
  fill_color="rgba(0, 0, 0, 0)",
@@ -16,3 +39,48 @@ canvas_result = st_canvas(
16
  key="canvas",
17
  display_toolbar=True,
18
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
+ import numpy as np
3
+ import plotly.express as px
4
+
5
  import streamlit as st
6
  from streamlit_drawable_canvas import st_canvas
7
 
8
+ import cv2
9
+
10
+ from keras.models import load_model
11
+
12
+ # --- App Configuration ---
13
+ st.set_page_config(page_title="Handwritten Digit Recognizer", layout="centered")
14
+ st.markdown("""
15
+ <style>
16
+ .stButton>button {background-color: #4b7bec; color: white; border-radius: 8px; padding: 10px;}
17
+ .stButton>button:hover {background-color: #3867d6;}
18
+ </style>
19
+ """, unsafe_allow_html=True)
20
+
21
+ # --- Load Model ---
22
+ @st.cache_resource
23
+ def load_digit_model():
24
+ return load_model("MNIST_Classifier.keras")
25
+
26
+
27
+
28
+ # --- Drawing Canvas ---
29
  st.subheader("Draw Your Digit")
30
  canvas_result = st_canvas(
31
  fill_color="rgba(0, 0, 0, 0)",
 
39
  key="canvas",
40
  display_toolbar=True,
41
  )
42
+
43
+ # --- Predict Button ---
44
+ predict_clicked = st.button("🔍 Predict", use_container_width=True, key="predict_button")
45
+
46
+ if predict_clicked and canvas_result.image_data is not None:
47
+ img = cv2.cvtColor(canvas_result.image_data.astype(np.uint8), cv2.COLOR_RGBA2GRAY)
48
+
49
+ if np.all(img == 255):
50
+ st.warning("⚠️ Please draw something before predicting!")
51
+ else:
52
+ img = 255 - img # Inverting the colors to mimic the dataset
53
+ img_resized = cv2.resize(img, (32, 32), interpolation=cv2.INTER_AREA)
54
+ img_normalized = img_resized.astype("float32") / 255.0
55
+ input_img = img_normalized.reshape(1, 32, 32, 1)
56
+
57
+ model = load_digit_model()
58
+ pred_probs = model.predict(input_img)
59
+ pred_class = np.argmax(pred_probs)
60
+ confidence = np.max(pred_probs)
61
+
62
+ st.subheader("Prediction Results")
63
+ col_img, col_result = st.columns([1, 2])
64
+
65
+ with col_img:
66
+ st.image(img_resized, caption="Processed Drawing", width=100, clamp=True)
67
+
68
+ with col_result:
69
+ st.success(f"🧠 Predicted Digit: **{pred_class}**")
70
+ st.info(f"🔍 Confidence: **{confidence * 100:.2f}%**")
71
+
72
+ # --- Plot probabilities ---
73
+ probs_df = pd.DataFrame({
74
+ "Digit": list(range(10)),
75
+ "Probability": pred_probs[0] * 100
76
+ })
77
+ fig = px.bar(probs_df, x="Digit", y="Probability",
78
+ title="Prediction Probabilities",
79
+ color="Probability",
80
+ color_continuous_scale="Blues",
81
+ height=300)
82
+ fig.update_layout(xaxis_title="Digit", yaxis_title="Probability (%)", xaxis=dict(tickmode="linear"))
83
+ st.plotly_chart(fig, use_container_width=True)
84
+
85
+ elif predict_clicked:
86
+ st.warning("⚠️ Please draw something before predicting!")