satya11 commited on
Commit
5665daf
·
verified ·
1 Parent(s): b1218d2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +78 -0
src/streamlit_app.py CHANGED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_drawable_canvas import st_canvas
3
+ from keras.models import load_model
4
+ import numpy as np
5
+ import cv2
6
+
7
+ # Streamlit page setup
8
+ st.set_page_config(page_title="Digit AI", layout="centered")
9
+
10
+ # App title and intro
11
+ st.markdown("""
12
+ <style>
13
+ .canvas-wrapper {
14
+ border: 2px dashed #aaa;
15
+ padding: 10px;
16
+ margin-bottom: 10px;
17
+ }
18
+ .prediction-box {
19
+ font-size: 28px;
20
+ font-weight: bold;
21
+ margin-top: 10px;
22
+ }
23
+ .emoji {
24
+ font-size: 48px;
25
+ }
26
+ </style>
27
+ """, unsafe_allow_html=True)
28
+
29
+ st.markdown("<h1>Digit Recognizer</h1>", unsafe_allow_html=True)
30
+ st.markdown("<p>Draw a digit (0–9) below and see what the AI thinks it is!</p>", unsafe_allow_html=True)
31
+
32
+ # Sidebar: Drawing settings
33
+ st.sidebar.markdown("### ✏️ Drawing Settings")
34
+ drawing_mode = st.sidebar.selectbox("Tool", ("freedraw", "line", "rect", "circle", "transform"))
35
+ stroke_width = st.sidebar.slider("Stroke Width", 1, 25, 10)
36
+ stroke_color = st.sidebar.color_picker("Stroke Color", "#FFFFFF")
37
+ bg_color = st.sidebar.color_picker("Background Color", "#000000")
38
+ realtime_update = st.sidebar.checkbox("Update Realtime", True)
39
+
40
+ # Load model
41
+ @st.cache_resource
42
+ def load_mnist_model():
43
+ return load_model("digit_recognization.keras")
44
+
45
+ model = load_mnist_model()
46
+
47
+ # Drawing canvas
48
+ st.markdown('<div class="canvas-wrapper">', unsafe_allow_html=True)
49
+ canvas_result = st_canvas(
50
+ fill_color="rgba(255, 255, 255, 0.05)",
51
+ stroke_width=stroke_width,
52
+ stroke_color=stroke_color,
53
+ background_color=bg_color,
54
+ update_streamlit=realtime_update,
55
+ height=280,
56
+ width=280,
57
+ drawing_mode=drawing_mode,
58
+ key="canvas"
59
+ )
60
+ st.markdown('</div>', unsafe_allow_html=True)
61
+
62
+ # Process and predict
63
+ if canvas_result.image_data is not None:
64
+ # Convert RGBA to grayscale
65
+ img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
66
+ img_resized = cv2.resize(img, (28, 28))
67
+ img_normalized = img_resized / 255.0
68
+ img_reshaped = img_normalized.reshape((1, 28, 28, 1)) # add channel dimension
69
+
70
+ # Skip blank images
71
+ if np.sum(img_resized) > 10:
72
+ prediction = model.predict(img_reshaped, verbose=0)
73
+ predicted_digit = np.argmax(prediction)
74
+
75
+ st.markdown(f"<div class='prediction-box'>Prediction: {predicted_digit}</div>", unsafe_allow_html=True)
76
+ st.markdown(f"<div class='emoji'>{['0️⃣','1️⃣','2️⃣','3️⃣','4️⃣','5️⃣','6️⃣','7️⃣','8️⃣','9️⃣'][predicted_digit]}</div>", unsafe_allow_html=True)
77
+ else:
78
+ st.warning("Please draw a digit before predicting.")