Scribbler310 committed on
Commit
79f3351
·
verified ·
1 Parent(s): 10b2987

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -29
app.py CHANGED
@@ -2,58 +2,114 @@ import tensorflow as tf
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
 
5
 
6
  # 1. Load the model
7
  model = tf.keras.models.load_model('digit_recognizer.keras')
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def classify_digit(image):
10
  if image is None:
11
  return None
12
 
13
- # Robust check: Gradio 4.x Sketchpad returns a dictionary {'background':..., 'layers':..., 'composite':...}
14
  if isinstance(image, dict):
15
  image = image['composite']
16
-
17
- image = np.array(image)
18
 
19
- # --- PREPROCESSING ---
20
- # 1. Handle different input formats (RGBA from sketchpad, RGB from upload)
 
 
21
  if image.shape[-1] == 4:
22
- # RGBA: Composite onto white background then convert to Gray
23
  background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
24
  alpha = image[:, :, 3] / 255.0
25
  for c in range(3):
26
  background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
27
- image = cv2.cvtColor(background, cv2.COLOR_RGB2GRAY)
28
- elif len(image.shape) == 3 and image.shape[-1] == 3:
29
- # RGB: Convert to Gray
30
- image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
31
-
32
- # 2. Resize to 28x28 (Model Requirement)
33
- image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
34
-
35
- # 3. Invert Colors (Critical)
36
- # MNIST expects white digit on black background.
37
- # If image is mostly bright (white paper/canvas), invert it.
38
- if np.mean(image) > 127:
39
- image = 255 - image
40
 
41
- # 4. Normalize & Reshape
42
- image = image.reshape(1, 28, 28, 1) / 255.0
 
 
 
43
 
44
  # --- PREDICTION ---
45
- prediction = model.predict(image).flatten()
46
  return {str(i): float(prediction[i]) for i in range(10)}
47
 
48
  # --- UI SETUP ---
49
  with gr.Blocks() as demo:
50
  gr.Markdown("## Handwritten Digit Recognizer")
51
- gr.Markdown("Draw a digit (0-9) or upload a photo to test the model.")
52
 
53
  with gr.Tabs():
54
- # Tab 1: Drawing Interface
55
  with gr.Tab("Draw Digit"):
56
- # FIX: Use 'default_size' and 'colors' instead of 'thickness' and 'color'
57
  sketchpad = gr.Sketchpad(
58
  label="Draw Here",
59
  type="numpy",
@@ -61,16 +117,12 @@ with gr.Blocks() as demo:
61
  )
62
  btn_draw = gr.Button("Predict Drawing", variant="primary")
63
 
64
- # Tab 2: Upload Interface
65
  with gr.Tab("Upload Photo"):
66
- # FIX: Use 'sources=["upload", "clipboard"]' to avoid the source list error
67
  upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
68
  btn_upload = gr.Button("Predict Upload", variant="primary")
69
 
70
- # Output is shared
71
  label = gr.Label(num_top_classes=3, label="Prediction")
72
 
73
- # Connect both buttons to the same function
74
  btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
75
  btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)
76
 
 
2
  import gradio as gr
3
  import cv2
4
  import numpy as np
5
+ import math
6
 
7
  # 1. Load the model
8
  model = tf.keras.models.load_model('digit_recognizer.keras')
9
 
10
def preprocess_mnist_style(image):
    """
    Convert a user drawing or photo into the layout MNIST-style models expect.

    Steps:
    - Convert to grayscale when the input has a channel axis
    - Invert colors when the image is mostly bright, so the digit is light
      on a dark background
    - Crop to the bounding box of the non-zero pixels
    - Resize the crop to fit a 20x20 box while preserving aspect ratio
    - Paste it into a black 28x28 canvas and shift it so its center of mass
      sits at (14, 14)

    Args:
        image: 2-D grayscale or 3-D RGB numpy array. Presumably uint8 in
            0..255 (the brightness test and inversion assume that range).

    Returns:
        A (28, 28) uint8 array. An input with no non-zero pixels after
        inversion yields an all-black canvas.
    """
    # 1. Convert to grayscale if needed
    if len(image.shape) == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # 2. Invert if background is light (MNIST is white-on-black)
    if np.mean(image) > 127:
        image = 255 - image

    # 3. Find the bounding box of the digit (crop empty space)
    coords = cv2.findNonZero(image)
    if coords is None:
        # Nothing drawn: still hand back a 28x28 frame so callers that
        # reshape to (1, 28, 28, 1) do not fail on a blank canvas.
        return np.zeros((28, 28), dtype=np.uint8)

    x, y, w, h = cv2.boundingRect(coords)
    digit = image[y:y+h, x:x+w]

    # 4. Resize to fit inside a 20x20 box (leaving a 4px margin per side).
    #    The long side becomes 20; the short side is scaled by the same
    #    factor but clamped to at least 1 pixel, because a very thin stroke
    #    would otherwise round to 0 and cv2.resize rejects an empty size.
    rows, cols = digit.shape
    if rows > cols:
        factor = 20.0 / rows
        rows = 20
        cols = max(1, int(round(cols * factor)))
    else:
        factor = 20.0 / cols
        cols = 20
        rows = max(1, int(round(rows * factor)))

    # INTER_AREA for better quality downscaling
    digit = cv2.resize(digit, (cols, rows), interpolation=cv2.INTER_AREA)

    # 5. Paste the resized digit into the center of a black 28x28 canvas
    new_image = np.zeros((28, 28), dtype=np.uint8)
    pad_x = (28 - cols) // 2
    pad_y = (28 - rows) // 2
    new_image[pad_y:pad_y+rows, pad_x:pad_x+cols] = digit

    # 6. Re-center by center of mass: translate so the intensity-weighted
    #    centroid lands on (14, 14).
    moments = cv2.moments(new_image)
    if moments['m00'] > 0:
        cx = moments['m10'] / moments['m00']
        cy = moments['m01'] / moments['m00']

        shift_x = 14 - cx
        shift_y = 14 - cy

        # 2x3 affine matrix encoding a pure translation.
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        new_image = cv2.warpAffine(new_image, M, (28, 28))

    return new_image
75
+
76
  def classify_digit(image):
77
  if image is None:
78
  return None
79
 
80
+ # Handle Gradio 4.x dictionary input
81
  if isinstance(image, dict):
82
  image = image['composite']
 
 
83
 
84
+ image = np.array(image)
85
+
86
+ # --- INPUT HANDLING ---
87
+ # Handle RGBA (Transparent)
88
  if image.shape[-1] == 4:
89
+ # Create white background
90
  background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
91
  alpha = image[:, :, 3] / 255.0
92
  for c in range(3):
93
  background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
94
+ image = background
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # --- APPLY ROBUST PREPROCESSING ---
97
+ processed_image = preprocess_mnist_style(image)
98
+
99
+ # Normalize (0 to 1) and Reshape
100
+ final_input = processed_image.reshape(1, 28, 28, 1) / 255.0
101
 
102
  # --- PREDICTION ---
103
+ prediction = model.predict(final_input).flatten()
104
  return {str(i): float(prediction[i]) for i in range(10)}
105
 
106
  # --- UI SETUP ---
107
  with gr.Blocks() as demo:
108
  gr.Markdown("## Handwritten Digit Recognizer")
109
+ gr.Markdown("Draw a digit (0-9) or upload a photo. The robust preprocessing now centers and scales your input like real MNIST data.")
110
 
111
  with gr.Tabs():
 
112
  with gr.Tab("Draw Digit"):
 
113
  sketchpad = gr.Sketchpad(
114
  label="Draw Here",
115
  type="numpy",
 
117
  )
118
  btn_draw = gr.Button("Predict Drawing", variant="primary")
119
 
 
120
  with gr.Tab("Upload Photo"):
 
121
  upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
122
  btn_upload = gr.Button("Predict Upload", variant="primary")
123
 
 
124
  label = gr.Label(num_top_classes=3, label="Prediction")
125
 
 
126
  btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
127
  btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)
128