Update app.py
Browse files
app.py
CHANGED
|
@@ -129,7 +129,7 @@ def predict_digit(input_image):
|
|
| 129 |
return default_output
|
| 130 |
|
| 131 |
try:
|
| 132 |
-
#
|
| 133 |
if isinstance(input_image, dict):
|
| 134 |
img_array = input_image.get("composite", None)
|
| 135 |
if img_array is None:
|
|
@@ -140,28 +140,34 @@ def predict_digit(input_image):
|
|
| 140 |
if img_array is None:
|
| 141 |
return default_output
|
| 142 |
|
| 143 |
-
#
|
| 144 |
if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
|
| 145 |
-
if img_array.shape[-1] == 4: # RGBA
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
|
| 149 |
else:
|
| 150 |
-
grayscale = img_array
|
| 151 |
|
| 152 |
-
#
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
return default_output
|
| 155 |
|
| 156 |
-
#
|
| 157 |
-
|
| 158 |
-
# grayscale = 255 - grayscale
|
| 159 |
|
| 160 |
-
# Resize
|
| 161 |
img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
|
| 162 |
-
img = img.resize((28, 28), Image.
|
| 163 |
tensor_img = transform_fn(img).unsqueeze(0).to(device)
|
| 164 |
|
|
|
|
| 165 |
with torch.no_grad():
|
| 166 |
outputs = model(tensor_img)
|
| 167 |
probabilities = F.softmax(outputs, dim=1)[0]
|
|
@@ -172,6 +178,7 @@ def predict_digit(input_image):
|
|
| 172 |
print(f"Prediction error: {e}")
|
| 173 |
return default_output
|
| 174 |
|
|
|
|
| 175 |
# --- GRADIO INTERFACE ---
|
| 176 |
with gr.Blocks() as demo:
|
| 177 |
gr.Markdown(
|
|
|
|
| 129 |
return default_output
|
| 130 |
|
| 131 |
try:
|
| 132 |
+
# 1. Handle Gradio Sketchpad dictionary output
|
| 133 |
if isinstance(input_image, dict):
|
| 134 |
img_array = input_image.get("composite", None)
|
| 135 |
if img_array is None:
|
|
|
|
| 140 |
if img_array is None:
|
| 141 |
return default_output
|
| 142 |
|
| 143 |
+
# 2. Convert to Grayscale safely
|
| 144 |
if isinstance(img_array, np.ndarray) and img_array.ndim == 3:
|
| 145 |
+
if img_array.shape[-1] == 4: # RGBA (Canvas often uses alpha)
|
| 146 |
+
# If background is transparent/white, alpha channel might be inverted
|
| 147 |
+
grayscale = img_array[..., 3]
|
| 148 |
+
else: # RGB -> Grayscale
|
| 149 |
grayscale = np.dot(img_array[..., :3], [0.2989, 0.5870, 0.1140])
|
| 150 |
else:
|
| 151 |
+
grayscale = img_array.copy()
|
| 152 |
|
| 153 |
+
# 3. AUTO-INVERT: Ensure white digit on black background
|
| 154 |
+
# If the average pixel value is bright (> 127), the user drew dark text on light background.
|
| 155 |
+
if np.mean(grayscale) > 127:
|
| 156 |
+
grayscale = 255.0 - grayscale
|
| 157 |
+
|
| 158 |
+
# 4. Check if the canvas is empty
|
| 159 |
+
if np.max(grayscale) < 15:
|
| 160 |
return default_output
|
| 161 |
|
| 162 |
+
# Debugging print to check what your model is actually receiving
|
| 163 |
+
print(f"Processed image shape: {grayscale.shape} | Max Val: {np.max(grayscale)} | Mean Val: {np.mean(grayscale):.2f}")
|
|
|
|
| 164 |
|
| 165 |
+
# 5. Convert to PIL, Resize, and Transform
|
| 166 |
img = Image.fromarray(grayscale.astype(np.uint8), mode="L")
|
| 167 |
+
img = img.resize((28, 28), Image.Resampling.BILINEAR)
|
| 168 |
tensor_img = transform_fn(img).unsqueeze(0).to(device)
|
| 169 |
|
| 170 |
+
# 6. Model Inference
|
| 171 |
with torch.no_grad():
|
| 172 |
outputs = model(tensor_img)
|
| 173 |
probabilities = F.softmax(outputs, dim=1)[0]
|
|
|
|
| 178 |
print(f"Prediction error: {e}")
|
| 179 |
return default_output
|
| 180 |
|
| 181 |
+
|
| 182 |
# --- GRADIO INTERFACE ---
|
| 183 |
with gr.Blocks() as demo:
|
| 184 |
gr.Markdown(
|