rohan poudel commited on
Commit
6ab34b0
·
1 Parent(s): 4b2e449

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -25,10 +25,10 @@ def preprocess_image(image):
25
  img = np.expand_dims(img, axis=0)
26
  return img
27
 
28
- # Define the function to make predictions on an image
29
 
 
30
 
31
- def predict(image):
32
  # Load the TFLite model
33
  interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
34
  interpreter.allocate_tensors()
@@ -38,7 +38,12 @@ def predict(image):
38
  output_details = interpreter.get_output_details()
39
 
40
  # Preprocess image
41
- img = preprocess_image(image)
 
 
 
 
 
42
 
43
  # Set input tensor
44
  interpreter.set_tensor(input_details[0]['index'], img)
@@ -65,12 +70,10 @@ def predict(image):
65
 
66
 
67
  # Define the Gradio interface
68
- inputs = gr.inputs.ImageUploader(
69
- label="Select an image or capture using camera")
70
  outputs = gr.outputs.Label(num_top_classes=5)
71
  interface = gr.Interface(fn=predict, inputs=inputs,
72
  outputs=outputs, capture_session=True)
73
 
74
-
75
  # Run the interface
76
  interface.launch()
 
25
  img = np.expand_dims(img, axis=0)
26
  return img
27
 
 
28
 
29
+ # Define the function to make predictions on an image
30
 
31
+ def predict(image_path_or_pil_image):
32
  # Load the TFLite model
33
  interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
34
  interpreter.allocate_tensors()
 
38
  output_details = interpreter.get_output_details()
39
 
40
  # Preprocess image
41
+ if isinstance(image_path_or_pil_image, str):
42
+ img = cv2.imread(image_path_or_pil_image)
43
+ img = preprocess_image(img, input_details[0]['shape'][1:3])
44
+ else:
45
+ img = np.array(image_path_or_pil_image)
46
+ img = preprocess_image(img)
47
 
48
  # Set input tensor
49
  interpreter.set_tensor(input_details[0]['index'], img)
 
70
 
71
 
72
  # Define the Gradio interface
73
+ inputs = gr.inputs.Image()
 
74
  outputs = gr.outputs.Label(num_top_classes=5)
75
  interface = gr.Interface(fn=predict, inputs=inputs,
76
  outputs=outputs, capture_session=True)
77
 
 
78
  # Run the interface
79
  interface.launch()