Mpavan45 commited on
Commit
1a955b8
·
verified ·
1 Parent(s): afa4836

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -13
app.py CHANGED
@@ -80,15 +80,6 @@ def load_multi_digit_model():
80
  single_digit_model = load_single_digit_model()
81
  multi_digit_model = load_multi_digit_model()
82
 
83
- # === Helper function to clean prediction ===
84
- def clean_prediction(predicted_digits):
85
- """
86
- Removes junk or padded digits like trailing 0s or 1s and keeps only valid 0–9 digits.
87
- You can further tune this logic based on training patterns.
88
- """
89
- digits = [str(d) for d in predicted_digits if 0 <= d <= 9]
90
- return ''.join(digits)
91
-
92
  # === Sidebar controls ===
93
  st.sidebar.title("Canvas Settings")
94
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
@@ -134,15 +125,16 @@ if canvas_result.image_data is not None:
134
  img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
135
  img = 255 - img # Invert colors
136
 
137
- # Resize image to match model input dimensions
138
- img_resized = cv2.resize(img, (80, 28)) # Resize to match multi-digit model input shape
139
  img_normalized = img_resized / 255.0
140
- final_img = img_normalized.reshape(1, 28, 80, 1)
141
 
142
  # === Choose which model to use based on the image size ===
143
  # If image is more likely to be a single digit (e.g., smaller width), use the single digit model
144
  if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
145
  model_to_use = single_digit_model
 
146
  else:
147
  model_to_use = multi_digit_model
148
 
@@ -152,7 +144,7 @@ if canvas_result.image_data is not None:
152
  # For multi-digit model, decode and clean prediction
153
  if model_to_use == multi_digit_model:
154
  predicted_digits = [np.argmax(p[0]) for p in preds]
155
- predicted_str = clean_prediction(predicted_digits)
156
  else:
157
  # For single digit model, directly decode
158
  predicted_str = str(np.argmax(preds))
 
80
  single_digit_model = load_single_digit_model()
81
  multi_digit_model = load_multi_digit_model()
82
 
 
 
 
 
 
 
 
 
 
83
  # === Sidebar controls ===
84
  st.sidebar.title("Canvas Settings")
85
  drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
 
125
  img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
126
  img = 255 - img # Invert colors
127
 
128
+ # Resize image to match the expected model input shape
129
+ img_resized = cv2.resize(img, (100, 28)) # Resize to match multi-digit model input shape
130
  img_normalized = img_resized / 255.0
131
+ final_img = img_normalized.reshape(1, 28, 100, 1) # Adjust to model input shape
132
 
133
  # === Choose which model to use based on the image size ===
134
  # If image is more likely to be a single digit (e.g., smaller width), use the single digit model
135
  if img_resized.shape[1] < 50: # This is an arbitrary threshold for width
136
  model_to_use = single_digit_model
137
+ final_img = final_img.reshape(1, 28, 28, 1) # For single digit, reshape accordingly
138
  else:
139
  model_to_use = multi_digit_model
140
 
 
144
  # For multi-digit model, decode and clean prediction
145
  if model_to_use == multi_digit_model:
146
  predicted_digits = [np.argmax(p[0]) for p in preds]
147
+ predicted_str = ''.join([str(d) for d in predicted_digits])
148
  else:
149
  # For single digit model, directly decode
150
  predicted_str = str(np.argmax(preds))