awais0300 commited on
Commit
2c194f5
·
verified ·
1 Parent(s): 6555f40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -17
app.py CHANGED
@@ -3,40 +3,37 @@ import tensorflow as tf
3
  import numpy as np
4
  import cv2
5
 
6
- # Load the SavedModel
7
- model = tf.keras.models.load_model("mask_mobilenet_savedmodel") # Folder path
8
 
9
- # Prediction function
10
  def predict_mask(image):
11
  try:
12
- # Convert to RGB if needed
13
- if image.shape[2] == 4: # RGBA
14
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
15
 
16
- # Resize to model input
17
  image = cv2.resize(image, (224, 224))
18
- image = image / 255.0 # Normalize
19
- image = np.expand_dims(image, axis=0) # Add batch dimension
20
 
21
- # Predict
22
- preds = model.predict(image)
23
- print("Preds:", preds) # Logs in console
24
-
25
- # Interpret prediction
26
  result = "Mask" if preds[0][0] > 0.5 else "No Mask"
27
  return result
28
 
29
  except Exception as e:
30
- print("Error:", e) # Logs in console
31
  return f"Error: {e}"
32
 
33
  # Gradio interface
34
  iface = gr.Interface(
35
  fn=predict_mask,
36
  inputs=gr.Image(type="numpy"),
37
- outputs=gr.Textbox(label="Prediction"),
38
  title="Mask Detection",
39
- description="Upload an image to check if a person is wearing a mask or not."
40
  )
41
 
42
- iface.launch()
 
 
3
  import numpy as np
4
  import cv2
5
 
6
+ # Load your SavedModel folder
7
+ model = tf.keras.models.load_model("mask_mobilenet_savedmodel")
8
 
 
9
  def predict_mask(image):
10
  try:
11
+ # Convert RGBA to RGB if needed
12
+ if image.shape[2] == 4:
13
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
14
 
15
+ # Resize and normalize
16
  image = cv2.resize(image, (224, 224))
17
+ image = image / 255.0
18
+ image = np.expand_dims(image, axis=0)
19
 
20
+ # Predict using SavedModel callable
21
+ preds = model(image, training=False).numpy()
 
 
 
22
  result = "Mask" if preds[0][0] > 0.5 else "No Mask"
23
  return result
24
 
25
  except Exception as e:
26
+ print("Error:", e)
27
  return f"Error: {e}"
28
 
29
  # Gradio interface
30
  iface = gr.Interface(
31
  fn=predict_mask,
32
  inputs=gr.Image(type="numpy"),
33
+ outputs="text",
34
  title="Mask Detection",
35
+ description="Upload an image to check if a person is wearing a mask."
36
  )
37
 
38
+ if __name__ == "__main__":
39
+ iface.launch()