HarnithaS commited on
Commit
3ea7ff1
·
1 Parent(s): 2fabd6e

added prediction code

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -26,9 +26,8 @@ def preprocess_image(drawing, size=(28, 28)):
26
  img_array = np.expand_dims(img_array, axis=-1)
27
  return img_array
28
 
29
- def preprocess_and_predict(img):
30
  model = load_model("mnist_cnn_model.h5")
31
- image = np.array(img) / 255.0
32
  # Expand dimensions to match the input shape expected by the model
33
  image = np.expand_dims(image, axis=0)
34
  # Reshape to match the input shape expected by the model
@@ -45,28 +44,25 @@ def main():
45
  # Create a drawing canvas
46
  drawing = st_canvas(
47
  fill_color="rgb(0, 0, 0)", # Background color of the canvas
48
- stroke_width=10, # Stroke width
49
  stroke_color="rgb(255, 255, 255)", # Stroke color
50
  background_color="#000000", # Background color of the canvas component
51
- height=150, # Height of the canvas
52
- width=150, # Width of the canvas
53
  drawing_mode="freedraw", # Drawing mode: "freedraw" or "transform"
54
  key="canvas",
55
  )
 
56
 
57
  # Check if the user has drawn anything
58
- if drawing is not None:
59
- st.image(drawing.image_data)
60
 
61
  # Preprocess the drawn image
62
  processed_image = preprocess_image(drawing.image_data)
63
- st.write("Processed Image Shape:", processed_image.shape)
64
-
65
- # Save the processed image
66
- np.save("processed_image.npy", processed_image)
67
- print(processed_image.shape)
68
  digit_class = preprocess_and_predict(processed_image)
 
69
  st.success(digit_class)
 
70
 
71
  if __name__ == "__main__":
72
  main()
 
26
  img_array = np.expand_dims(img_array, axis=-1)
27
  return img_array
28
 
29
+ def preprocess_and_predict(image):
30
  model = load_model("mnist_cnn_model.h5")
 
31
  # Expand dimensions to match the input shape expected by the model
32
  image = np.expand_dims(image, axis=0)
33
  # Reshape to match the input shape expected by the model
 
44
  # Create a drawing canvas
45
  drawing = st_canvas(
46
  fill_color="rgb(0, 0, 0)", # Background color of the canvas
47
+ stroke_width=4, # Stroke width
48
  stroke_color="rgb(255, 255, 255)", # Stroke color
49
  background_color="#000000", # Background color of the canvas component
50
+ height=168, # Height of the canvas
51
+ width=168, # Width of the canvas
52
  drawing_mode="freedraw", # Drawing mode: "freedraw" or "transform"
53
  key="canvas",
54
  )
55
+ predict = st.button('Predict digit')
56
 
57
  # Check if the user has drawn anything
58
+ if predict is True:
 
59
 
60
  # Preprocess the drawn image
61
  processed_image = preprocess_image(drawing.image_data)
 
 
 
 
 
62
  digit_class = preprocess_and_predict(processed_image)
63
+ st.title("Predicted Digit:")
64
  st.success(digit_class)
65
+ predict = False
66
 
67
  if __name__ == "__main__":
68
  main()