Harika22 commited on
Commit
caa1208
·
verified ·
1 Parent(s): e141f54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py CHANGED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_drawable_canvas import st_canvas
3
+ from keras.models import load_model
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image
7
+
8
+ # Unique title for the app
9
+ st.title("Handwritten Digit Recognizer")
10
+
11
+ # Sidebar controls
12
+ drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
13
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
14
+ stroke_color = st.sidebar.color_picker("Stroke color hex: ","#FFFFFF")
15
+ bg_color = st.sidebar.color_picker("Background color hex: ","#000000" )
16
+ uploaded_digit_image = st.sidebar.file_uploader("Upload a digit image (for prediction):", type=["png", "jpg"])
17
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
18
+
19
+ # Load model
20
+ @st.cache_resource
21
+ def load_mnist_model():
22
+ return load_model('digit_reco.keras')
23
+
24
+ model = load_mnist_model()
25
+
26
+ # Improved Preprocessing function
27
+ def preprocess_image(image_data):
28
+ img_rgb = cv2.cvtColor(image_data.astype("uint8"), cv2.COLOR_RGBA2RGB)
29
+ gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
30
+ inverted = cv2.bitwise_not(gray)
31
+ _, binary = cv2.threshold(inverted, 50, 255, cv2.THRESH_BINARY)
32
+ coords = cv2.findNonZero(binary)
33
+ if coords is not None:
34
+ x, y, w, h = cv2.boundingRect(coords)
35
+ cropped = binary[y:y+h, x:x+w]
36
+ else:
37
+ return np.zeros((28, 28))
38
+ height, width = cropped.shape
39
+ if height > width:
40
+ new_height = 20
41
+ new_width = int(width * (20.0 / height))
42
+ else:
43
+ new_width = 20
44
+ new_height = int(height * (20.0 / width))
45
+ resized = cv2.resize(cropped, (new_width, new_height), interpolation=cv2.INTER_AREA)
46
+ padded = np.zeros((28, 28), dtype=np.uint8)
47
+ x_offset = (28 - resized.shape[1]) // 2
48
+ y_offset = (28 - resized.shape[0]) // 2
49
+ padded[y_offset:y_offset+resized.shape[0], x_offset:x_offset+resized.shape[1]] = resized
50
+ normalized = padded / 255.0
51
+ return normalized
52
+
53
+ # Handle uploaded image prediction
54
+ if uploaded_digit_image is not None:
55
+ st.subheader("📤 Uploaded Image Prediction")
56
+ image = Image.open(uploaded_digit_image).convert("RGBA")
57
+ image = image.resize((280, 280))
58
+ img_array = np.array(image)
59
+ st.image(img_array, caption="Uploaded Image")
60
+
61
+ processed_image = preprocess_image(img_array)
62
+ st.image(processed_image, width=150, caption="Processed Input (28x28)")
63
+
64
+ img_reshaped = processed_image.reshape(1, 28, 28, 1)
65
+ prediction = model.predict(img_reshaped)
66
+ predicted_digit = int(np.argmax(prediction))
67
+
68
+ st.markdown(
69
+ f"""
70
+ <div style='text-align: center; font-size: 60px; font-weight: bold; color: #2E86C1;
71
+ text-shadow: 2px 2px 4px #aaa; margin-top: 20px;'>
72
+ Predicted Digit: {predicted_digit}
73
+ </div>
74
+ """,
75
+ unsafe_allow_html=True
76
+ )
77
+
78
+ # Handle canvas drawing
79
+ canvas_result = st_canvas(
80
+ fill_color="rgba(255, 165, 0, 0.3)",
81
+ stroke_width=stroke_width,
82
+ stroke_color=stroke_color,
83
+ background_color=bg_color,
84
+ update_streamlit=realtime_update,
85
+ height=280,
86
+ width=280,
87
+ drawing_mode=drawing_mode,
88
+ key="canvas",
89
+ )
90
+
91
+ if canvas_result.image_data is not None:
92
+ st.subheader("✏️ Canvas Drawing Prediction")
93
+ st.image(canvas_result.image_data, caption="Original Drawing")
94
+ processed_image = preprocess_image(canvas_result.image_data)
95
+ st.image(processed_image, width=150, caption="Processed Input (28x28)")
96
+
97
+ img_reshaped = processed_image.reshape(1, 28, 28, 1)
98
+ prediction = model.predict(img_reshaped)
99
+ predicted_digit = int(np.argmax(prediction))
100
+
101
+ st.markdown(
102
+ f"""
103
+ <div style='text-align: center; font-size: 60px; font-weight: bold; color: #2E86C1;
104
+ text-shadow: 2px 2px 4px #aaa; margin-top: 20px;'>
105
+ Predicted Digit: {predicted_digit}
106
+ </div>
107
+ """,
108
+ unsafe_allow_html=True
109
+ )