Mpavan45 commited on
Commit
0fdd3c4
·
verified ·
1 Parent(s): d1b0437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -55
app.py CHANGED
@@ -1,64 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import cv2
3
- from streamlit_drawable_canvas import st_canvas
4
- from keras.models import load_model
5
  import numpy as np
 
 
6
 
7
- # Sidebar controls
8
- st.sidebar.title("Canvas Settings")
9
- drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
 
 
 
 
 
 
 
10
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
11
- stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
12
- bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
13
- bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
14
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
15
 
16
- # Load model with caching
17
- @st.cache_resource
18
- def load_mnist_model():
19
- return load_model("mnist_model.keras")
20
 
21
- model = load_mnist_model()
22
-
23
- st.title("🖌️ Mindist: Draw a Number, Predict Instantly")
24
-
25
- # Create a two-column layout
26
- col1, col2 = st.columns([1, 1])
27
-
28
- with col1:
29
- st.subheader("Draw Here 👇")
30
- canvas_result = st_canvas(
31
- fill_color="rgba(255, 165, 0, 0.3)",
32
- stroke_width=stroke_width,
33
- stroke_color=stroke_color,
34
- background_color=bg_color,
35
- update_streamlit=realtime_update,
36
- height=280,
37
- width=280,
38
- drawing_mode=drawing_mode,
39
- key="canvas",
40
- )
41
-
42
- with col2:
43
- if canvas_result.image_data is not None:
44
- st.subheader("Original Drawing")
45
- st.image(canvas_result.image_data, use_column_width=True)
46
-
47
- # Below the two columns: Show preprocessing and prediction
48
  if canvas_result.image_data is not None:
49
- st.markdown("---")
50
- st.subheader("Preprocessed Image & Prediction")
51
-
52
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
53
- img = 255 - img # Invert colors
54
- img_resized = cv2.resize(img, (28, 28))
55
- img_normalized = img_resized / 255.0
56
- final_img = img_normalized.reshape(1, 28, 28, 1)
57
-
58
- col3, col4 = st.columns([1, 1])
59
- with col3:
60
- st.image(img_resized, caption="28x28 Preprocessed", clamp=True, channels="GRAY")
61
- with col4:
62
- prediction = model.predict(final_img)
63
- predicted_digit = np.argmax(prediction)
64
- st.markdown(f"### 🧠 Predicted Digit: **{predicted_digit}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import streamlit as st
2
+ # import cv2
3
+ # from streamlit_drawable_canvas import st_canvas
4
+ # from keras.models import load_model
5
+ # import numpy as np
6
+
7
+ # # Sidebar controls
8
+ # st.sidebar.title("Canvas Settings")
9
+ # drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
10
+ # stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
11
+ # stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
12
+ # bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
13
+ # bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
14
+ # realtime_update = st.sidebar.checkbox("Update in realtime", True)
15
+
16
+ # # Load model with caching
17
+ # @st.cache_resource
18
+ # def load_mnist_model():
19
+ # return load_model("mnist_model.keras")
20
+
21
+ # model = load_mnist_model()
22
+
23
+ # st.title("🖌️ Mindist: Draw a Number, Predict Instantly")
24
+
25
+ # # Create a two-column layout
26
+ # col1, col2 = st.columns([1, 1])
27
+
28
+ # with col1:
29
+ # st.subheader("Draw Here 👇")
30
+ # canvas_result = st_canvas(
31
+ # fill_color="rgba(255, 165, 0, 0.3)",
32
+ # stroke_width=stroke_width,
33
+ # stroke_color=stroke_color,
34
+ # background_color=bg_color,
35
+ # update_streamlit=realtime_update,
36
+ # height=280,
37
+ # width=280,
38
+ # drawing_mode=drawing_mode,
39
+ # key="canvas",
40
+ # )
41
+
42
+ # with col2:
43
+ # if canvas_result.image_data is not None:
44
+ # st.subheader("Original Drawing")
45
+ # st.image(canvas_result.image_data, use_column_width=True)
46
+
47
+ # # Below the two columns: Show preprocessing and prediction
48
+ # if canvas_result.image_data is not None:
49
+ # st.markdown("---")
50
+ # st.subheader("Preprocessed Image & Prediction")
51
+
52
+ # img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
53
+ # img = 255 - img # Invert colors
54
+ # img_resized = cv2.resize(img, (28, 28))
55
+ # img_normalized = img_resized / 255.0
56
+ # final_img = img_normalized.reshape(1, 28, 28, 1)
57
+
58
+ # col3, col4 = st.columns([1, 1])
59
+ # with col3:
60
+ # st.image(img_resized, caption="28x28 Preprocessed", clamp=True, channels="GRAY")
61
+ # with col4:
62
+ # prediction = model.predict(final_img)
63
+ # predicted_digit = np.argmax(prediction)
64
+ # st.markdown(f"### 🧠 Predicted Digit: **{predicted_digit}**")
65
  import streamlit as st
66
  import cv2
 
 
67
  import numpy as np
68
+ from streamlit_drawable_canvas import st_canvas
69
+ from tensorflow.keras.models import load_model
70
 
71
+ # === Load model (trained on single digits) ===
72
+ @st.cache_resource
73
+ def load_digit_model():
74
+ return load_model("mnist_model.keras")
75
+
76
+ model = load_digit_model()
77
+
78
+ # === Sidebar Controls ===
79
+ st.sidebar.title("🛠️ Settings")
80
+ mode = st.sidebar.radio("Choose Prediction Mode", ("Single Digit", "Multi Digit"))
81
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
82
+ stroke_color = st.sidebar.color_picker("Stroke color: ", "#000000")
83
+ bg_color = st.sidebar.color_picker("Background color: ", "#FFFFFF")
 
84
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
85
 
86
+ # === Title ===
87
+ st.title("🧠 MNIST Digit Recognizer")
88
+ st.caption("Draw digits and let the model predict them! Choose **Single** or **Multi** mode from the sidebar.")
 
89
 
90
+ # === Drawing Canvas ===
91
+ canvas_result = st_canvas(
92
+ fill_color="rgba(255, 165, 0, 0.3)",
93
+ stroke_width=stroke_width,
94
+ stroke_color=stroke_color,
95
+ background_color=bg_color,
96
+ update_streamlit=realtime_update,
97
+ height=280,
98
+ width=280,
99
+ drawing_mode="freedraw",
100
+ key="canvas",
101
+ )
102
+
103
+ # === Prediction Logic ===
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  if canvas_result.image_data is not None:
105
+ st.subheader("🖼️ Original Drawing")
106
+ st.image(canvas_result.image_data, use_column_width=False, width=280)
107
+
108
+ img = canvas_result.image_data.astype('uint8')
109
+ gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
110
+ gray = 255 - gray # Invert for white digits on black
111
+ _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)
112
+
113
+ if mode == "Single Digit":
114
+ st.subheader("🔢 Single Digit Mode")
115
+ resized = cv2.resize(thresh, (28, 28))
116
+ normalized = resized.astype("float32") / 255.0
117
+ input_tensor = normalized.reshape(1, 28, 28, 1)
118
+ pred = model.predict(input_tensor, verbose=0)
119
+ digit = np.argmax(pred)
120
+ st.image(resized, width=100, caption="🧼 Cleaned & Resized")
121
+ st.success(f"✅ Predicted Digit: **{digit}**")
122
+
123
+ elif mode == "Multi Digit":
124
+ st.subheader("🔢 Multi Digit Mode")
125
+
126
+ # Detect contours
127
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
128
+ boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 50]
129
+ boxes = sorted(boxes, key=lambda b: b[0]) # sort by x
130
+
131
+ predictions = []
132
+ for x, y, w, h in boxes:
133
+ digit_img = thresh[y:y+h, x:x+w]
134
+ digit_img = cv2.resize(digit_img, (28, 28))
135
+ normalized = digit_img.astype("float32") / 255.0
136
+ input_tensor = normalized.reshape(1, 28, 28, 1)
137
+ pred = model.predict(input_tensor, verbose=0)
138
+ digit = np.argmax(pred)
139
+ predictions.append((digit_img, digit))
140
+
141
+ if predictions:
142
+ st.markdown("### ✂️ Segmented Digits & Predictions")
143
+ cols = st.columns(len(predictions))
144
+ for i, (img, digit) in enumerate(predictions):
145
+ with cols[i]:
146
+ st.image(img, width=64, caption=f"➡️ {digit}")
147
+ full_number = ''.join(str(d) for (_, d) in predictions)
148
+ st.success(f"📌 Final Multi-Digit Prediction: **{full_number}**")
149
+ else:
150
+ st.warning("⚠️ Couldn't detect any digits. Try writing more clearly.")