Update app.py
Browse files
app.py
CHANGED
|
@@ -80,15 +80,6 @@ def load_multi_digit_model():
|
|
| 80 |
single_digit_model = load_single_digit_model()
|
| 81 |
multi_digit_model = load_multi_digit_model()
|
| 82 |
|
| 83 |
-
# === Helper function to clean prediction ===
|
| 84 |
-
def clean_prediction(predicted_digits):
|
| 85 |
-
"""
|
| 86 |
-
Removes junk or padded digits like trailing 0s or 1s and keeps only valid 0–9 digits.
|
| 87 |
-
You can further tune this logic based on training patterns.
|
| 88 |
-
"""
|
| 89 |
-
digits = [str(d) for d in predicted_digits if 0 <= d <= 9]
|
| 90 |
-
return ''.join(digits)
|
| 91 |
-
|
| 92 |
# === Sidebar controls ===
|
| 93 |
st.sidebar.title("Canvas Settings")
|
| 94 |
drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
|
|
@@ -134,15 +125,16 @@ if canvas_result.image_data is not None:
|
|
| 134 |
img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
|
| 135 |
img = 255 - img # Invert colors
|
| 136 |
|
| 137 |
-
# Resize image to match model input
|
| 138 |
-
img_resized = cv2.resize(img, (
|
| 139 |
img_normalized = img_resized / 255.0
|
| 140 |
-
final_img = img_normalized.reshape(1, 28,
|
| 141 |
|
| 142 |
# === Choose which model to use based on the image size ===
|
| 143 |
# If image is more likely to be a single digit (e.g., smaller width), use the single digit model
|
| 144 |
if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
|
| 145 |
model_to_use = single_digit_model
|
|
|
|
| 146 |
else:
|
| 147 |
model_to_use = multi_digit_model
|
| 148 |
|
|
@@ -152,7 +144,7 @@ if canvas_result.image_data is not None:
|
|
| 152 |
# For multi-digit model, decode and clean prediction
|
| 153 |
if model_to_use == multi_digit_model:
|
| 154 |
predicted_digits = [np.argmax(p[0]) for p in preds]
|
| 155 |
-
predicted_str =
|
| 156 |
else:
|
| 157 |
# For single digit model, directly decode
|
| 158 |
predicted_str = str(np.argmax(preds))
|
|
|
|
| 80 |
single_digit_model = load_single_digit_model()
|
| 81 |
multi_digit_model = load_multi_digit_model()
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# === Sidebar controls ===
|
| 84 |
st.sidebar.title("Canvas Settings")
|
| 85 |
drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
|
|
|
|
| 125 |
img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
|
| 126 |
img = 255 - img # Invert colors
|
| 127 |
|
| 128 |
+
# Resize image to match the expected model input shape
|
| 129 |
+
img_resized = cv2.resize(img, (100, 28)) # Resize to match multi-digit model input shape
|
| 130 |
img_normalized = img_resized / 255.0
|
| 131 |
+
final_img = img_normalized.reshape(1, 28, 100, 1) # Adjust to model input shape
|
| 132 |
|
| 133 |
# === Choose which model to use based on the image size ===
|
| 134 |
# If image is more likely to be a single digit (e.g., smaller width), use the single digit model
|
| 135 |
if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
|
| 136 |
model_to_use = single_digit_model
|
| 137 |
+
final_img = final_img.reshape(1, 28, 28, 1) # For single digit, reshape accordingly
|
| 138 |
else:
|
| 139 |
model_to_use = multi_digit_model
|
| 140 |
|
|
|
|
| 144 |
# For multi-digit model, decode and clean prediction
|
| 145 |
if model_to_use == multi_digit_model:
|
| 146 |
predicted_digits = [np.argmax(p[0]) for p in preds]
|
| 147 |
+
predicted_str = ''.join([str(d) for d in predicted_digits])
|
| 148 |
else:
|
| 149 |
# For single digit model, directly decode
|
| 150 |
predicted_str = str(np.argmax(preds))
|