aryamanpathak commited on
Commit
1177d0b
·
verified ·
1 Parent(s): d9e65dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -51
app.py CHANGED
@@ -1,15 +1,13 @@
 
1
  import cv2
2
  import numpy as np
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
- import matplotlib.pyplot as plt
6
- import os
7
 
8
- # Define image dimensions
9
- IMG_HEIGHT = 256
10
- IMG_WIDTH = 256
11
 
12
- # Load trained model
13
  def iou_metric(y_true, y_pred):
14
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
15
  intersection = tf.reduce_sum(y_true * y_pred)
@@ -21,50 +19,41 @@ def dice_metric(y_true, y_pred):
21
  intersection = tf.reduce_sum(y_true * y_pred)
22
  return (2. * intersection) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1e-7)
23
 
 
24
  model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric})
25
 
26
- # Function to preprocess input image
27
- def preprocess_image(image_path):
28
- img = cv2.imread(image_path)
29
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
30
- img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT)) / 255.0 # Resize and normalize
31
- return np.expand_dims(img, axis=0) # Add batch dimension
32
-
33
- # Function to make predictions
34
- def predict_mask(image_path):
35
- img = preprocess_image(image_path)
36
- mask_pred = model.predict(img)[0] # Remove batch dimension
37
- mask_pred = (mask_pred > 0.5).astype(np.uint8) # Threshold mask
38
- return mask_pred
39
-
40
- # Function to visualize results
41
- def visualize_results(image_path, mask_pred):
42
- img = cv2.imread(image_path)
43
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert to RGB
44
- img_resized = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
45
-
46
- fig, ax = plt.subplots(1, 2, figsize=(12, 4))
47
- ax[0].imshow(img_resized)
48
- ax[0].set_title("Original Image")
49
- ax[0].axis("off")
50
-
51
- ax[1].imshow(mask_pred, cmap="gray")
52
- ax[1].set_title("Predicted Mask")
53
- ax[1].axis("off")
54
-
55
- # overlay = img_resized.copy()
56
- # overlay[mask_pred == 1] = [255, 0, 0] # Apply red overlay on mask
57
- # ax[2].imshow(overlay)
58
- # ax[2].set_title("Overlay")
59
- # ax[2].axis("off")
60
-
61
- plt.show()
62
-
63
- # Test on an image
64
- TEST_IMAGE = "/content/MSFD/1/face_crop/000020_1.jpg" # Change this to your test image path
65
- if os.path.exists(TEST_IMAGE):
66
- mask_prediction = predict_mask(TEST_IMAGE)
67
- visualize_results(TEST_IMAGE, mask_prediction)
68
- else:
69
- print(f"Test image '{TEST_IMAGE}' not found. Please check the path.")
70
-
 
1
+ import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import tensorflow as tf
5
  from tensorflow.keras.models import load_model
 
 
6
 
7
+ # Constants
8
+ IMG_HEIGHT, IMG_WIDTH = 256, 256
 
9
 
10
+ # Custom metrics
11
  def iou_metric(y_true, y_pred):
12
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
13
  intersection = tf.reduce_sum(y_true * y_pred)
 
19
  intersection = tf.reduce_sum(y_true * y_pred)
20
  return (2. * intersection) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1e-7)
21
 
22
+ # Load model
23
  model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric})
24
 
25
+ # Predict mask function
26
+ def predict(image):
27
+ # Convert PIL to NumPy array
28
+ img_np = np.array(image)
29
+ img_resized = cv2.resize(img_np, (IMG_WIDTH, IMG_HEIGHT)) / 255.0
30
+ img_input = np.expand_dims(img_resized, axis=0)
31
+
32
+ # Predict mask
33
+ mask_pred = model.predict(img_input)[0]
34
+ mask_binary = (mask_pred > 0.5).astype(np.uint8) * 255
35
+
36
+ # Create overlay
37
+ overlay = img_resized.copy()
38
+ overlay[mask_binary == 255] = [1.0, 0, 0] # Red overlay in normalized scale
39
+
40
+ # Convert overlay and mask to displayable format
41
+ overlay_display = (overlay * 255).astype(np.uint8)
42
+ mask_display = mask_binary.astype(np.uint8)
43
+
44
+ return overlay_display, mask_display
45
+
46
+ # Gradio Interface
47
+ interface = gr.Interface(
48
+ fn=predict,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=[
51
+ gr.Image(type="numpy", label="Overlay"),
52
+ gr.Image(type="numpy", label="Predicted Mask")
53
+ ],
54
+ title="Face Mask Segmentation",
55
+ description="Upload a face image to get the predicted segmentation mask using U-Net."
56
+ )
57
+
58
+ if __name__ == "__main__":
59
+ interface.launch()