ASomeoneWhoInterestedWithAI committed on
Commit
077e064
·
verified ·
1 Parent(s): 51849ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -13
app.py CHANGED
@@ -129,7 +129,7 @@ def predict_digit(input_image):
129
  return default_output
130
 
131
  try:
132
- # Extract the background or composite layer from the Gradio Sketchpad dictionary
133
  if isinstance(input_image, dict):
134
  img_array = input_image.get("composite", None)
135
  if img_array is None:
@@ -140,28 +140,34 @@ def predict_digit(input_image):
140
  if img_array is None:
141
  return default_output
142
 
143
- # Extract channels safely
144
  if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
145
- if img_array.shape[-1] == 4: # RGBA -> alpha channel
146
- grayscale = img_array[..., 3]
147
- else: # RGB -> luminance
 
148
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
149
  else:
150
- grayscale = img_array
151
 
152
- # Check if canvas is essentially empty
153
- if np.max(grayscale) < 5:
 
 
 
 
 
154
  return default_output
155
 
156
- # Ensure the background is black and the text is white (standard MNIST setup)
157
- # If your brush was black and canvas was white, invert it here:
158
- # grayscale = 255 - grayscale
159
 
160
- # Resize & normalize
161
  img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
162
- img = img.resize((28, 28), Image.Image.Resampling.BILINEAR if hasattr(Image, 'Image') else Image.BILINEAR)
163
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
164
 
 
165
  with torch.no_grad():
166
  outputs = model(tensor_img)
167
  probabilities = F.softmax(outputs, dim=1)[0]
@@ -172,6 +178,7 @@ def predict_digit(input_image):
172
  print(f"Prediction error: {e}")
173
  return default_output
174
 
 
175
  # --- GRADIO INTERFACE ---
176
  with gr.Blocks() as demo:
177
  gr.Markdown(
 
129
  return default_output
130
 
131
  try:
132
+ # 1. Handle Gradio Sketchpad dictionary output
133
  if isinstance(input_image, dict):
134
  img_array = input_image.get("composite", None)
135
  if img_array is None:
 
140
  if img_array is None:
141
  return default_output
142
 
143
+ # 2. Convert to Grayscale safely
144
  if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
145
+ if img_array.shape[-1] == 4: # RGBA (Canvas often uses alpha)
146
+ # If background is transparent/white, alpha channel might be inverted
147
+ grayscale = img_array[..., 3]
148
+ else: # RGB -> Grayscale
149
  grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
150
  else:
151
+ grayscale = img_array.copy()
152
 
153
+ # 3. AUTO-INVERT: Ensure white digit on black background
154
+ # If the average pixel value is bright (> 127), the user drew dark text on light background.
155
+ if np.mean(grayscale) > 127:
156
+ grayscale = 255.0 - grayscale
157
+
158
+ # 4. Check if the canvas is empty
159
+ if np.max(grayscale) < 15:
160
  return default_output
161
 
162
+ # Debugging print to check what your model is actually receiving
163
+ print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")
 
164
 
165
+ # 5. Convert to PIL, Resize, and Transform
166
  img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
167
+ img = img.resize((28, 28), Image.Resampling.BILINEAR)
168
  tensor_img = transform_fn(img).unsqueeze(0).to(device)
169
 
170
+ # 6. Model Inference
171
  with torch.no_grad():
172
  outputs = model(tensor_img)
173
  probabilities = F.softmax(outputs, dim=1)[0]
 
178
  print(f"Prediction error: {e}")
179
  return default_output
180
 
181
+
182
  # --- GRADIO INTERFACE ---
183
  with gr.Blocks() as demo:
184
  gr.Markdown(