Josephus67 commited on
Commit
b5efc49
·
verified ·
1 Parent(s): ce9d2d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -12
app.py CHANGED
@@ -3,26 +3,48 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # Load MNIST model
 
 
 
7
  model = tf.keras.models.load_model("mnist_model.h5")
8
 
 
 
 
9
  def predict(image):
10
- # Ensure PIL image, convert to grayscale, resize
 
 
 
 
11
  image = image.convert("L").resize((28, 28))
12
-
13
- # Normalize
14
- image = np.array(image) / 255.0
15
- image = image.reshape(1, 28, 28, 1)
16
 
17
  # Predict
18
- prediction = model.predict(image)
19
- return str(np.argmax(prediction))
 
 
 
 
 
 
 
20
 
21
- # Gradio interface (upload image)
 
 
22
  iface = gr.Interface(
23
  fn=predict,
24
- inputs=gr.Image(type="pil", image_mode="L"), # ← removed shape
25
- outputs="label"
 
 
26
  )
27
 
28
- iface.launch()
 
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # =========================
7
+ # Load trained model
8
+ # =========================
9
+ # Make sure you've trained and saved it as best_model.h5 in your notebook
10
  model = tf.keras.models.load_model("mnist_model.h5")
11
 
12
+ # =========================
13
+ # Prediction function
14
+ # =========================
15
  def predict(image):
16
+ """
17
+ Takes a PIL image, preprocesses it (grayscale + resize),
18
+ runs prediction using trained model, and returns predicted digit.
19
+ """
20
+ # Convert to grayscale + resize
21
  image = image.convert("L").resize((28, 28))
22
+
23
+ # Convert to numpy and normalize
24
+ img_array = np.array(image) / 255.0
25
+ img_array = img_array.reshape(1, 28, 28, 1) # batch shape
26
 
27
  # Predict
28
+ prediction = model.predict(img_array)
29
+ predicted_class = np.argmax(prediction, axis=1)[0]
30
+
31
+ # Also return top-3 predictions with probabilities
32
+ top3_indices = prediction[0].argsort()[-3:][::-1]
33
+ top3_probs = prediction[0][top3_indices]
34
+
35
+ result = {str(d): float(p) for d, p in zip(top3_indices, top3_probs)}
36
+ return result
37
 
38
+ # =========================
39
+ # Gradio interface
40
+ # =========================
41
  iface = gr.Interface(
42
  fn=predict,
43
+ inputs=gr.Image(type="pil", image_mode="L"),
44
+ outputs=gr.Label(num_top_classes=3), # show top 3 predictions
45
+ title="MNIST Digit Classifier",
46
+ description="Upload a handwritten digit (0–9) image. The model will predict the digit."
47
  )
48
 
49
+ if __name__ == "__main__":
50
+ iface.launch()