Anshini commited on
Commit
4ea7f89
·
verified ·
1 Parent(s): e5ac68e

Update app1.py

Browse files
Files changed (1) hide show
  1. app1.py +5 -7
app1.py CHANGED
@@ -6,27 +6,25 @@ from PIL import Image
6
  # Load the trained model
7
  model = load_model("mnist_model.h5")
8
 
9
- # Define prediction function
10
  def predict_digit(image):
11
- # Resize and normalize
12
- image = image.convert('L').resize((28, 28)) # convert to grayscale and resize
13
  img_array = np.array(image).astype("float32") / 255.0
14
  img_array = img_array.reshape(1, 28, 28)
15
 
16
- # Predict
17
  prediction = model.predict(img_array)
18
  predicted_class = np.argmax(prediction)
19
  confidence = float(np.max(prediction))
20
 
21
  return f"Prediction: {predicted_class} (Confidence: {confidence:.2f})"
22
 
23
- # Define Gradio Interface
24
  interface = gr.Interface(
25
  fn=predict_digit,
26
- inputs=gr.Image(type="pil", shape=(200, 200), label="Upload a Digit Image"),
27
  outputs=gr.Textbox(label="Prediction"),
28
  title="Handwritten Digit Recognition",
29
- description="Upload a handwritten digit image (0–9) to classify it using a neural network trained on the MNIST dataset."
30
  )
31
 
32
  interface.launch()
 
6
  # Load the trained model
7
  model = load_model("mnist_model.h5")
8
 
9
+ # Prediction function
10
  def predict_digit(image):
11
+ image = image.convert('L').resize((28, 28))
 
12
  img_array = np.array(image).astype("float32") / 255.0
13
  img_array = img_array.reshape(1, 28, 28)
14
 
 
15
  prediction = model.predict(img_array)
16
  predicted_class = np.argmax(prediction)
17
  confidence = float(np.max(prediction))
18
 
19
  return f"Prediction: {predicted_class} (Confidence: {confidence:.2f})"
20
 
21
+ # Gradio Interface (no shape argument)
22
  interface = gr.Interface(
23
  fn=predict_digit,
24
+ inputs=gr.Image(type="pil", label="Upload a Digit Image"),
25
  outputs=gr.Textbox(label="Prediction"),
26
  title="Handwritten Digit Recognition",
27
+ description="Upload a handwritten digit image (0–9) to classify it using a model trained on the MNIST dataset."
28
  )
29
 
30
  interface.launch()