Scribbler310 committed on
Commit
e932bc5
·
verified ·
1 Parent(s): 38fdc77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -45
app.py CHANGED
@@ -3,64 +3,72 @@ import gradio as gr
3
  import cv2
4
  import numpy as np
5
 
6
- # 1. Load your saved model
7
  model = tf.keras.models.load_model('digit_recognizer.keras')
8
 
9
  def classify_digit(image):
10
- # Error handling: if no image is provided
11
- if image is None:
12
  return None
13
-
14
- # --- PREPROCESSING ---
15
- # Convert to numpy array if it isn't already
 
 
16
  image = np.array(image)
17
-
18
- # 1. Handle Color Channels
19
- # If image has 4 channels (RGBA) from sketchpad, convert to Gray
20
  if image.shape[-1] == 4:
21
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
22
- # If image has 3 channels (RGB) from upload, convert to Gray
23
- elif image.shape[-1] == 3:
 
 
 
 
 
24
  image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
25
 
26
- # 2. Resize to 28x28
27
- # We use INTER_AREA for shrinking which preserves details better than default
28
  image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
29
-
30
- # 3. Invert Colors (Critical Step)
31
- # MNIST models expect White Text on Black Background.
32
- # If the image is mostly bright (like white paper), we must invert it.
33
- avg_brightness = np.mean(image)
34
- if avg_brightness > 127: # If the image is mostly white/light
35
- image = 255 - image # Invert to black background
36
-
37
- # 4. Reshape for Model
38
- # (1 sample, 28 height, 28 width, 1 channel)
39
- image = image.reshape(1, 28, 28, 1)
40
 
41
- # 5. Normalize (0 to 1)
42
- image = image / 255.0
43
-
 
 
 
 
 
 
44
  # --- PREDICTION ---
45
  prediction = model.predict(image).flatten()
46
  return {str(i): float(prediction[i]) for i in range(10)}
47
 
48
- # --- GRADIO INTERFACE ---
49
- # sources=["upload", "canvas"] enables both file upload and drawing
50
- interface = gr.Interface(
51
- fn=classify_digit,
52
- inputs=gr.Image(
53
- type="numpy",
54
- label="Draw or Upload Digit",
55
- image_mode="L", # "L" attempts to convert to grayscale immediately
56
- sources=["upload", "canvas"],
57
- height=400,
58
- width=400
59
- ),
60
- outputs=gr.Label(num_top_classes=3),
61
- title="Handwritten Digit Recognizer",
62
- description="Draw a digit on the canvas OR upload a photo of a digit. The model will guess what it is."
63
- )
 
 
 
 
 
 
 
 
64
 
65
  if __name__ == "__main__":
66
- interface.launch()
 
3
  import cv2
4
  import numpy as np
5
 
6
+ # 1. Load the model
7
  model = tf.keras.models.load_model('digit_recognizer.keras')
8
 
9
  def classify_digit(image):
10
+ if image is None:
 
11
  return None
12
+
13
+ # Robust check: Gradio 4.x Sketchpad might return a dictionary
14
+ if isinstance(image, dict):
15
+ image = image['composite']
16
+
17
  image = np.array(image)
18
+
19
+ # --- PREPROCESSING ---
20
+ # 1. Handle different input formats (RGBA from sketchpad, RGB from upload)
21
  if image.shape[-1] == 4:
22
+ # RGBA: Composite onto white background then convert to Gray
23
+ background = np.ones((image.shape[0], image.shape[1], 3), dtype=np.uint8) * 255
24
+ alpha = image[:, :, 3] / 255.0
25
+ for c in range(3):
26
+ background[:, :, c] = alpha * image[:, :, c] + (1 - alpha) * background[:, :, c]
27
+ image = cv2.cvtColor(background, cv2.COLOR_RGB2GRAY)
28
+ elif len(image.shape) == 3 and image.shape[-1] == 3:
29
+ # RGB: Convert to Gray
30
  image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
31
 
32
+ # 2. Resize to 28x28 (Model Requirement)
 
33
  image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # 3. Invert Colors (Critical)
36
+ # MNIST expects white digit on black background.
37
+ # If image is mostly bright (white paper/canvas), invert it.
38
+ if np.mean(image) > 127:
39
+ image = 255 - image
40
+
41
+ # 4. Normalize & Reshape
42
+ image = image.reshape(1, 28, 28, 1) / 255.0
43
+
44
  # --- PREDICTION ---
45
  prediction = model.predict(image).flatten()
46
  return {str(i): float(prediction[i]) for i in range(10)}
47
 
48
+ # --- UI SETUP ---
49
+ # We use gr.Blocks to create a custom layout with Tabs
50
+ with gr.Blocks() as demo:
51
+ gr.Markdown("## Handwritten Digit Recognizer")
52
+ gr.Markdown("Draw a digit (0-9) or upload a photo to test the model.")
53
+
54
+ with gr.Tabs():
55
+ # Tab 1: Drawing Interface
56
+ with gr.Tab("Draw Digit"):
57
+ sketchpad = gr.Sketchpad(label="Draw Here", type="numpy", brush=gr.Brush(color="#000000", thickness=20))
58
+ btn_draw = gr.Button("Predict Drawing", variant="primary")
59
+
60
+ # Tab 2: Upload Interface
61
+ with gr.Tab("Upload Photo"):
62
+ # sources=["upload", "clipboard"] fixes your specific error
63
+ upload = gr.Image(label="Upload Image", sources=["upload", "clipboard"], type="numpy")
64
+ btn_upload = gr.Button("Predict Upload", variant="primary")
65
+
66
+ # Output is shared
67
+ label = gr.Label(num_top_classes=3, label="Prediction")
68
+
69
+ # Connect both buttons to the same function
70
+ btn_draw.click(fn=classify_digit, inputs=sketchpad, outputs=label)
71
+ btn_upload.click(fn=classify_digit, inputs=upload, outputs=label)
72
 
73
  if __name__ == "__main__":
74
+ demo.launch()