Alen Hovhannisians commited on
Commit
187aaea
·
1 Parent(s): bf2aba1
Files changed (2) hide show
  1. app.py +22 -25
  2. requirements.txt +1 -1
app.py CHANGED
@@ -5,36 +5,35 @@ from PIL import Image
5
 
6
  MODEL_PATH = "mnist_cnn.h5"
7
 
8
- # Load model (CPU only – CUDA warnings are normal)
9
  model = tf.keras.models.load_model(MODEL_PATH)
10
 
11
 
12
  def preprocess(image):
13
- # image arrives as numpy array (H, W, C)
14
  if image is None:
15
  return None
16
 
17
- # Convert to PIL
18
- image = Image.fromarray(image.astype("uint8"))
 
19
 
20
  # Convert to grayscale
21
  image = image.convert("L")
22
 
23
- # Resize to MNIST
24
  image = image.resize((28, 28))
25
 
26
  img = np.array(image).astype("float32")
27
 
28
- # Invert colors (MNIST = white digit on black)
29
  img = 255 - img
30
 
31
- # Threshold to remove gray noise
32
  img[img < 40] = 0
33
 
34
  # Normalize
35
  img /= 255.0
36
 
37
- # Add channel & batch dimensions
38
  img = np.expand_dims(img, axis=-1)
39
  img = np.expand_dims(img, axis=0)
40
 
@@ -47,22 +46,20 @@ def predict(image):
47
  return {str(i): float(preds[i]) for i in range(10)}
48
 
49
 
50
- with gr.Blocks() as demo:
51
- gr.Markdown("# ✍️ MNIST Handwritten Digit Classifier")
52
- gr.Markdown("Draw a **single large digit**, centered, like MNIST.")
53
-
54
- canvas = gr.Image(
55
- label="Draw a digit",
56
- tool="sketch", # correct canvas for Gradio 4.x
57
- image_mode="RGB",
58
- height=280,
59
- width=280
60
- )
61
-
62
- output = gr.Label(num_top_classes=3)
63
-
64
- btn = gr.Button("Predict")
65
-
66
- btn.click(fn=predict, inputs=canvas, outputs=output)
67
 
68
  demo.launch()
 
5
 
6
  MODEL_PATH = "mnist_cnn.h5"
7
 
 
8
  model = tf.keras.models.load_model(MODEL_PATH)
9
 
10
 
11
  def preprocess(image):
 
12
  if image is None:
13
  return None
14
 
15
+ # Ensure PIL
16
+ if isinstance(image, np.ndarray):
17
+ image = Image.fromarray(image.astype("uint8"))
18
 
19
  # Convert to grayscale
20
  image = image.convert("L")
21
 
22
+ # Resize
23
  image = image.resize((28, 28))
24
 
25
  img = np.array(image).astype("float32")
26
 
27
+ # Invert colors (MNIST style)
28
  img = 255 - img
29
 
30
+ # Remove background noise
31
  img[img < 40] = 0
32
 
33
  # Normalize
34
  img /= 255.0
35
 
36
+ # Add dims
37
  img = np.expand_dims(img, axis=-1)
38
  img = np.expand_dims(img, axis=0)
39
 
 
46
  return {str(i): float(preds[i]) for i in range(10)}
47
 
48
 
49
+ demo = gr.Interface(
50
+ fn=predict,
51
+ inputs=gr.Image(
52
+ label="Upload a digit image (white digit on dark background)"
53
+ ),
54
+ outputs=gr.Label(num_top_classes=3),
55
+ title="MNIST Handwritten Digit Classifier",
56
+ description=(
57
+ "Upload an image of a handwritten digit.\n\n"
58
+ "Tips:\n"
59
+ "- Dark background\n"
60
+ "- Light digit\n"
61
+ "- Centered and large"
62
+ ),
63
+ )
 
 
64
 
65
  demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  tensorflow>=2.11
2
  numpy
3
  pillow
4
- gradio>=4.0
 
1
  tensorflow>=2.11
2
  numpy
3
  pillow
4
+ gradio