Update app.py
Browse files
app.py
CHANGED
|
@@ -116,36 +116,47 @@ with col2:
|
|
| 116 |
st.subheader("Original Drawing")
|
| 117 |
st.image(canvas_result.image_data, use_column_width=True)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
# === Image preprocessing and prediction ===
|
| 120 |
if canvas_result.image_data is not None:
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
st.subheader("Preprocessed Image & Prediction")
|
| 123 |
-
|
| 124 |
# Preprocess image
|
| 125 |
-
img = cv2.cvtColor(
|
| 126 |
img = 255 - img # Invert colors
|
| 127 |
-
|
| 128 |
-
#
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
if img_resized.shape[1] < 50: # If the image is narrow, assume it's a single digit
|
| 135 |
model_to_use = single_digit_model
|
|
|
|
|
|
|
| 136 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
model_to_use = multi_digit_model
|
| 138 |
-
|
| 139 |
-
# Predict using the selected model
|
| 140 |
-
preds = model_to_use.predict(final_img)
|
| 141 |
-
|
| 142 |
-
# For multi-digit model, decode and clean prediction
|
| 143 |
-
if model_to_use == multi_digit_model:
|
| 144 |
predicted_digits = [np.argmax(p[0]) for p in preds]
|
| 145 |
predicted_str = ''.join([str(d) for d in predicted_digits])
|
| 146 |
-
else:
|
| 147 |
-
# For single digit model, directly decode
|
| 148 |
-
predicted_str = str(np.argmax(preds))
|
| 149 |
|
| 150 |
-
# Show
|
| 151 |
st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
|
|
|
|
| 116 |
st.subheader("Original Drawing")
|
| 117 |
st.image(canvas_result.image_data, use_column_width=True)
|
| 118 |
|
| 119 |
+
# === Image preprocessing and prediction ===
|
| 120 |
+
# Initialize session state for draw count
|
| 121 |
+
if "draw_count" not in st.session_state:
|
| 122 |
+
st.session_state.draw_count = 0
|
| 123 |
+
st.session_state.last_image = None
|
| 124 |
+
|
| 125 |
# === Image preprocessing and prediction ===
|
| 126 |
if canvas_result.image_data is not None:
|
| 127 |
+
current_image = canvas_result.image_data
|
| 128 |
+
|
| 129 |
+
# Compare current image with previous
|
| 130 |
+
if st.session_state.last_image is None or not np.array_equal(current_image, st.session_state.last_image):
|
| 131 |
+
st.session_state.draw_count += 1
|
| 132 |
+
st.session_state.last_image = current_image
|
| 133 |
+
|
| 134 |
+
st.markdown(f"### ✏️ Draw Count: {st.session_state.draw_count}")
|
| 135 |
+
|
| 136 |
st.subheader("Preprocessed Image & Prediction")
|
| 137 |
+
|
| 138 |
# Preprocess image
|
| 139 |
+
img = cv2.cvtColor(current_image.astype("uint8"), cv2.COLOR_RGBA2GRAY)
|
| 140 |
img = 255 - img # Invert colors
|
| 141 |
+
|
| 142 |
+
# === Model Selection Based on Draw Count ===
|
| 143 |
+
if st.session_state.draw_count == 1:
|
| 144 |
+
# Assume single-digit drawing
|
| 145 |
+
img_resized = cv2.resize(img, (28, 28))
|
| 146 |
+
img_normalized = img_resized / 255.0
|
| 147 |
+
final_img = img_normalized.reshape(1, 28, 28, 1)
|
|
|
|
| 148 |
model_to_use = single_digit_model
|
| 149 |
+
preds = model_to_use.predict(final_img)
|
| 150 |
+
predicted_str = str(np.argmax(preds))
|
| 151 |
else:
|
| 152 |
+
# Assume multi-digit drawing
|
| 153 |
+
img_resized = cv2.resize(img, (100, 28))
|
| 154 |
+
img_normalized = img_resized / 255.0
|
| 155 |
+
final_img = img_normalized.reshape(1, 28, 100, 1)
|
| 156 |
model_to_use = multi_digit_model
|
| 157 |
+
preds = model_to_use.predict(final_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
predicted_digits = [np.argmax(p[0]) for p in preds]
|
| 159 |
predicted_str = ''.join([str(d) for d in predicted_digits])
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
# Show result
|
| 162 |
st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
|