Scribbler310 committed on
Commit
38fdc77
·
verified ·
1 Parent(s): f176231

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -23
app.py CHANGED
@@ -1,47 +1,66 @@
1
  import tensorflow as tf
2
  import gradio as gr
3
- import numpy as np
4
  import cv2
 
5
 
6
- # 1. Load the trained model
7
  model = tf.keras.models.load_model('digit_recognizer.keras')
8
 
9
- # 2. Define the classification function
10
  def classify_digit(image):
 
11
  if image is None:
12
  return None
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Preprocessing to match MNIST data format
15
- # Convert to grayscale if it isn't already
16
- if len(image.shape) == 3:
17
- image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
18
-
19
- # Resize the image to 28x28 pixels
20
- image = cv2.resize(image, (28, 28))
21
-
22
- # Reshape to (1, 28, 28, 1) to match model input shape
23
- # The '1' indicates a batch size of 1
 
 
 
24
  image = image.reshape(1, 28, 28, 1)
25
 
26
- # Normalize pixel values (0 to 1) just like in the training notebook
27
  image = image / 255.0
28
-
29
- # Predict
30
  prediction = model.predict(image).flatten()
31
-
32
- # Return dictionary for Gradio Label output
33
  return {str(i): float(prediction[i]) for i in range(10)}
34
 
35
- # 3. Build the Gradio Interface
36
- # We use Sketchpad so users can draw the digit
37
  interface = gr.Interface(
38
  fn=classify_digit,
39
- inputs=gr.Sketchpad(label="Draw a Digit"),
 
 
 
 
 
 
 
40
  outputs=gr.Label(num_top_classes=3),
41
  title="Handwritten Digit Recognizer",
42
- description="Draw a digit (0-9) on the canvas to see if the Neural Network recognizes it."
43
  )
44
 
45
- # 4. Launch
46
  if __name__ == "__main__":
47
  interface.launch()
 
1
  import tensorflow as tf
2
  import gradio as gr
 
3
  import cv2
4
+ import numpy as np
5
 
6
# 1. Load your saved model
# Executed once at module import; the file name is a relative path, so it is
# presumably resolved against the app's working directory — verify on deploy.
model = tf.keras.models.load_model('digit_recognizer.keras')
8
 
 
9
  def classify_digit(image):
10
+ # Error handling: if no image is provided
11
  if image is None:
12
  return None
13
+
14
+ # --- PREPROCESSING ---
15
+ # Convert to numpy array if it isn't already
16
+ image = np.array(image)
17
+
18
+ # 1. Handle Color Channels
19
+ # If image has 4 channels (RGBA) from sketchpad, convert to Gray
20
+ if image.shape[-1] == 4:
21
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
22
+ # If image has 3 channels (RGB) from upload, convert to Gray
23
+ elif image.shape[-1] == 3:
24
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
25
 
26
+ # 2. Resize to 28x28
27
+ # We use INTER_AREA for shrinking which preserves details better than default
28
+ image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_AREA)
29
+
30
+ # 3. Invert Colors (Critical Step)
31
+ # MNIST models expect White Text on Black Background.
32
+ # If the image is mostly bright (like white paper), we must invert it.
33
+ avg_brightness = np.mean(image)
34
+ if avg_brightness > 127: # If the image is mostly white/light
35
+ image = 255 - image # Invert to black background
36
+
37
+ # 4. Reshape for Model
38
+ # (1 sample, 28 height, 28 width, 1 channel)
39
  image = image.reshape(1, 28, 28, 1)
40
 
41
+ # 5. Normalize (0 to 1)
42
  image = image / 255.0
43
+
44
+ # --- PREDICTION ---
45
  prediction = model.predict(image).flatten()
 
 
46
  return {str(i): float(prediction[i]) for i in range(10)}
47
 
48
# --- GRADIO INTERFACE ---
# One image input wired to classify_digit; the label output shows the top
# three classes.
# NOTE(review): "canvas" may not be an accepted entry for gr.Image `sources`
# in the installed Gradio version — confirm against the gr.Image reference.
interface = gr.Interface(
    title="Handwritten Digit Recognizer",
    description="Draw a digit on the canvas OR upload a photo of a digit. The model will guess what it is.",
    fn=classify_digit,
    inputs=gr.Image(
        label="Draw or Upload Digit",
        sources=["upload", "canvas"],
        type="numpy",
        image_mode="L",  # "L" requests a single-channel grayscale image
        width=400,
        height=400,
    ),
    outputs=gr.Label(num_top_classes=3),
)

# Start the app only when this file is executed directly.
if __name__ == "__main__":
    interface.launch()