Mpavan45 commited on
Commit
a5750ad
·
verified ·
1 Parent(s): c57143d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -21
app.py CHANGED
@@ -116,36 +116,47 @@ with col2:
116
  st.subheader("Original Drawing")
117
  st.image(canvas_result.image_data, use_column_width=True)
118
 
 
 
 
 
 
 
119
  # === Image preprocessing and prediction ===
120
  if canvas_result.image_data is not None:
121
- st.markdown("---")
 
 
 
 
 
 
 
 
122
  st.subheader("Preprocessed Image & Prediction")
123
-
124
  # Preprocess image
125
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
126
  img = 255 - img # Invert colors
127
-
128
- # Resize image to match the expected model input shape
129
- img_resized = cv2.resize(img, (28, 28)) # Resize to 28x28 for single digit model
130
- img_normalized = img_resized / 255.0
131
- final_img = img_normalized.reshape(1, 28, 28, 1) # Adjust to model input shape
132
-
133
- # Choose the model based on image characteristics (height/width ratio, etc.)
134
- if img_resized.shape[1] < 50: # If the image is narrow, assume it's a single digit
135
  model_to_use = single_digit_model
 
 
136
  else:
 
 
 
 
137
  model_to_use = multi_digit_model
138
-
139
- # Predict using the selected model
140
- preds = model_to_use.predict(final_img)
141
-
142
- # For multi-digit model, decode and clean prediction
143
- if model_to_use == multi_digit_model:
144
  predicted_digits = [np.argmax(p[0]) for p in preds]
145
  predicted_str = ''.join([str(d) for d in predicted_digits])
146
- else:
147
- # For single digit model, directly decode
148
- predicted_str = str(np.argmax(preds))
149
 
150
- # Show prediction result
151
  st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
 
116
  st.subheader("Original Drawing")
117
  st.image(canvas_result.image_data, use_column_width=True)
118
 
119
+ # === Image preprocessing and prediction ===
120
+ # Initialize session state for draw count
121
+ if "draw_count" not in st.session_state:
122
+ st.session_state.draw_count = 0
123
+ st.session_state.last_image = None
124
+
125
  # === Image preprocessing and prediction ===
126
  if canvas_result.image_data is not None:
127
+ current_image = canvas_result.image_data
128
+
129
+ # Compare current image with previous
130
+ if st.session_state.last_image is None or not np.array_equal(current_image, st.session_state.last_image):
131
+ st.session_state.draw_count += 1
132
+ st.session_state.last_image = current_image
133
+
134
+ st.markdown(f"### ✏️ Draw Count: {st.session_state.draw_count}")
135
+
136
  st.subheader("Preprocessed Image & Prediction")
137
+
138
  # Preprocess image
139
+ img = cv2.cvtColor(current_image.astype("uint8"), cv2.COLOR_RGBA2GRAY)
140
  img = 255 - img # Invert colors
141
+
142
+ # === Model Selection Based on Draw Count ===
143
+ if st.session_state.draw_count == 1:
144
+ # Assume single-digit drawing
145
+ img_resized = cv2.resize(img, (28, 28))
146
+ img_normalized = img_resized / 255.0
147
+ final_img = img_normalized.reshape(1, 28, 28, 1)
 
148
  model_to_use = single_digit_model
149
+ preds = model_to_use.predict(final_img)
150
+ predicted_str = str(np.argmax(preds))
151
  else:
152
+ # Assume multi-digit drawing
153
+ img_resized = cv2.resize(img, (100, 28))
154
+ img_normalized = img_resized / 255.0
155
+ final_img = img_normalized.reshape(1, 28, 100, 1)
156
  model_to_use = multi_digit_model
157
+ preds = model_to_use.predict(final_img)
 
 
 
 
 
158
  predicted_digits = [np.argmax(p[0]) for p in preds]
159
  predicted_str = ''.join([str(d) for d in predicted_digits])
 
 
 
160
 
161
+ # Show result
162
  st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")