Mpavan45 commited on
Commit
6458591
Β·
verified Β·
1 Parent(s): 9b0af88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -129
app.py CHANGED
@@ -1,86 +1,10 @@
1
- # import streamlit as st
2
- # import cv2
3
- # from streamlit_drawable_canvas import st_canvas
4
- # from keras.models import load_model
5
- # import numpy as np
6
-
7
- # # Sidebar controls
8
- # st.sidebar.title("Canvas Settings")
9
- # drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
10
- # stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
11
- # stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
12
- # bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
13
- # bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
14
- # realtime_update = st.sidebar.checkbox("Update in realtime", True)
15
-
16
- # # Load model with caching
17
- # @st.cache_resource
18
- # def load_mnist_model():
19
- # return load_model("mnist_model.keras")
20
-
21
- # model = load_mnist_model()
22
-
23
- # st.title("πŸ–ŒοΈ Mindist: Draw a Number, Predict Instantly")
24
-
25
- # # Create a two-column layout
26
- # col1, col2 = st.columns([1, 1])
27
-
28
- # with col1:
29
- # st.subheader("Draw Here πŸ‘‡")
30
- # canvas_result = st_canvas(
31
- # fill_color="rgba(255, 165, 0, 0.3)",
32
- # stroke_width=stroke_width,
33
- # stroke_color=stroke_color,
34
- # background_color=bg_color,
35
- # update_streamlit=realtime_update,
36
- # height=280,
37
- # width=280,
38
- # drawing_mode=drawing_mode,
39
- # key="canvas",
40
- # )
41
-
42
- # with col2:
43
- # if canvas_result.image_data is not None:
44
- # st.subheader("Original Drawing")
45
- # st.image(canvas_result.image_data, use_column_width=True)
46
-
47
- # # Below the two columns: Show preprocessing and prediction
48
- # if canvas_result.image_data is not None:
49
- # st.markdown("---")
50
- # st.subheader("Preprocessed Image & Prediction")
51
-
52
- # img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
53
- # img = 255 - img # Invert colors
54
- # img_resized = cv2.resize(img, (28, 28))
55
- # img_normalized = img_resized / 255.0
56
- # final_img = img_normalized.reshape(1, 28, 28, 1)
57
-
58
- # col3, col4 = st.columns([1, 1])
59
- # with col3:
60
- # st.image(img_resized, caption="28x28 Preprocessed", clamp=True, channels="GRAY")
61
- # with col4:
62
- # prediction = model.predict(final_img)
63
- # predicted_digit = np.argmax(prediction)
64
- # st.markdown(f"### 🧠 Predicted Digit: **{predicted_digit}**")
65
  import streamlit as st
66
  import cv2
67
- import numpy as np
68
- from keras.models import load_model
69
  from streamlit_drawable_canvas import st_canvas
 
 
70
 
71
- # === Load models ===
72
- @st.cache_resource
73
- def load_single_digit_model():
74
- return load_model("mnist_model.keras")
75
-
76
- @st.cache_resource
77
- def load_multi_digit_model():
78
- return load_model("best_model.keras") # multi-digit model
79
-
80
- single_digit_model = load_single_digit_model()
81
- multi_digit_model = load_multi_digit_model()
82
-
83
- # === Sidebar controls ===
84
  st.sidebar.title("Canvas Settings")
85
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
86
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
@@ -89,13 +13,18 @@ bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
89
  bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
90
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
91
 
92
- # === Title ===
93
- st.title("πŸ–ŒοΈ Multi-Digit and Single-Digit Drawing: Predict Instantly")
 
 
 
 
94
 
95
- # === Create a two-column layout ===
 
 
96
  col1, col2 = st.columns([1, 1])
97
 
98
- # === Canvas for drawing ===
99
  with col1:
100
  st.subheader("Draw Here πŸ‘‡")
101
  canvas_result = st_canvas(
@@ -110,57 +39,26 @@ with col1:
110
  key="canvas",
111
  )
112
 
113
- # === Display original drawing ===
114
  with col2:
115
  if canvas_result.image_data is not None:
116
  st.subheader("Original Drawing")
117
  st.image(canvas_result.image_data, use_column_width=True)
118
 
119
- # === Image preprocessing and prediction ===
120
- # === Sidebar reset ===
121
- if st.sidebar.button("πŸ”„ Reset"):
122
- st.session_state.draw_count = 0
123
- st.session_state.last_image = None
124
- st.experimental_rerun()
125
-
126
- # === Initialize draw counter ===
127
- if "draw_count" not in st.session_state:
128
- st.session_state.draw_count = 0
129
- st.session_state.last_image = None
130
-
131
- # === Preprocess and predict ===
132
  if canvas_result.image_data is not None:
133
- current_image = canvas_result.image_data
134
-
135
- # Check if drawing has changed
136
- if st.session_state.last_image is None or not np.array_equal(current_image, st.session_state.last_image):
137
- st.session_state.draw_count += 1
138
- st.session_state.last_image = current_image
139
-
140
- st.markdown(f"### ✏️ Draw Count: {st.session_state.draw_count}")
141
  st.subheader("Preprocessed Image & Prediction")
142
-
143
- # Convert and preprocess
144
- img = cv2.cvtColor(current_image.astype("uint8"), cv2.COLOR_RGBA2GRAY)
145
  img = 255 - img # Invert colors
146
-
147
- if st.session_state.draw_count == 1:
148
- # Single digit
149
- img_resized = cv2.resize(img, (28, 28))
150
- img_normalized = img_resized / 255.0
151
- final_img = img_normalized.reshape(1, 28, 28, 1)
152
- model_to_use = single_digit_model
153
- preds = model_to_use.predict(final_img)
154
- predicted_str = str(np.argmax(preds))
155
- else:
156
- # Multi digit
157
- img_resized = cv2.resize(img, (100, 28))
158
- img_normalized = img_resized / 255.0
159
- final_img = img_normalized.reshape(1, 28, 100, 1)
160
- model_to_use = multi_digit_model
161
- preds = model_to_use.predict(final_img)
162
- predicted_digits = [np.argmax(p[0]) for p in preds]
163
- predicted_str = ''.join([str(d) for d in predicted_digits])
164
-
165
- # Output
166
- st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import cv2
 
 
3
  from streamlit_drawable_canvas import st_canvas
4
+ from keras.models import load_model
5
+ import numpy as np
6
 
7
+ # Sidebar controls
 
 
 
 
 
 
 
 
 
 
 
 
8
  st.sidebar.title("Canvas Settings")
9
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
10
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 10)
 
13
  bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
14
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
15
 
16
+ # Load model with caching
17
+ @st.cache_resource
18
+ def load_mnist_model():
19
+ return load_model("mnist_model.keras")
20
+
21
+ model = load_mnist_model()
22
 
23
+ st.title("πŸ–ŒοΈ Mindist: Draw a Number, Predict Instantly")
24
+
25
+ # Create a two-column layout
26
  col1, col2 = st.columns([1, 1])
27
 
 
28
  with col1:
29
  st.subheader("Draw Here πŸ‘‡")
30
  canvas_result = st_canvas(
 
39
  key="canvas",
40
  )
41
 
 
42
  with col2:
43
  if canvas_result.image_data is not None:
44
  st.subheader("Original Drawing")
45
  st.image(canvas_result.image_data, use_column_width=True)
46
 
47
+ # Below the two columns: Show preprocessing and prediction
 
 
 
 
 
 
 
 
 
 
 
 
48
  if canvas_result.image_data is not None:
49
+ st.markdown("---")
 
 
 
 
 
 
 
50
  st.subheader("Preprocessed Image & Prediction")
51
+
52
+ img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
 
53
  img = 255 - img # Invert colors
54
+ img_resized = cv2.resize(img, (28, 28))
55
+ img_normalized = img_resized / 255.0
56
+ final_img = img_normalized.reshape(1, 28, 28, 1)
57
+
58
+ col3, col4 = st.columns([1, 1])
59
+ with col3:
60
+ st.image(img_resized, caption="28x28 Preprocessed", clamp=True, channels="GRAY")
61
+ with col4:
62
+ prediction = model.predict(final_img)
63
+ predicted_digit = np.argmax(prediction)
64
+ st.markdown(f"### 🧠 Predicted Digit: **{predicted_digit}**")