Mpavan45 commited on
Commit
c57143d
·
verified ·
1 Parent(s): 1f3d8ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -126,18 +126,16 @@ if canvas_result.image_data is not None:
126
  img = 255 - img # Invert colors
127
 
128
  # Resize image to match the expected model input shape
129
- img_resized = cv2.resize(img, (100, 28)) # Resize to match multi-digit model input shape
130
  img_normalized = img_resized / 255.0
131
- final_img = img_normalized.reshape(1, 28, 100, 1) # Adjust to model input shape
132
-
133
- # === Choose which model to use based on the image size ===
134
- # If image is more likely to be a single digit (e.g., smaller width), use the single digit model
135
- if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
136
  model_to_use = single_digit_model
137
- final_img = final_img.reshape(1, 28, 28, 1) # For single digit, reshape accordingly
138
  else:
139
  model_to_use = multi_digit_model
140
-
141
  # Predict using the selected model
142
  preds = model_to_use.predict(final_img)
143
 
@@ -151,4 +149,3 @@ if canvas_result.image_data is not None:
151
 
152
  # Show prediction result
153
  st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")
154
-
 
126
  img = 255 - img # Invert colors
127
 
128
  # Resize image to match the expected model input shape
129
+ img_resized = cv2.resize(img, (28, 28)) # Resize to 28x28 for single digit model
130
  img_normalized = img_resized / 255.0
131
+ final_img = img_normalized.reshape(1, 28, 28, 1) # Adjust to model input shape
132
+
133
+ # Choose the model based on image characteristics (height/width ratio, etc.)
134
+ if img_resized.shape[1] < 50: # If the image is narrow, assume it's a single digit
 
135
  model_to_use = single_digit_model
 
136
  else:
137
  model_to_use = multi_digit_model
138
+
139
  # Predict using the selected model
140
  preds = model_to_use.predict(final_img)
141
 
 
149
 
150
  # Show prediction result
151
  st.markdown(f"### 🧠 Predicted Number: **{predicted_str}**")