aryamanpathak commited on
Commit
fa16b96
·
verified ·
1 Parent(s): 1690b06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -34
app.py CHANGED
@@ -1,44 +1,71 @@
1
- import gradio as gr
2
  import numpy as np
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
- import cv2
 
6
 
7
- # Define the IOU metric exactly as provided
 
 
 
 
8
  def iou_metric(y_true, y_pred):
9
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
10
  intersection = tf.reduce_sum(y_true * y_pred)
11
  union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
12
  return intersection / (union + 1e-7)
13
 
14
- # Load the trained U-Net model
15
- model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric})
16
-
17
- # Preprocess the uploaded image for prediction
18
- def preprocess_image(image, target_size=(256, 256)):
19
- image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert RGB -> BGR for OpenCV ops
20
- image_resized = cv2.resize(image_bgr, target_size)
21
- image_resized = image_resized / 255.0 # Normalize
22
- return np.expand_dims(image_resized, axis=0) # Add batch dimension
23
-
24
- # Predict and return the segmented mask
25
- def segment_image(input_image):
26
- original_size = (input_image.shape[1], input_image.shape[0]) # (width, height)
27
- preprocessed = preprocess_image(input_image)
28
- pred_mask = model.predict(preprocessed)[0] # Remove batch dimension
29
- binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
30
- binary_mask_resized = cv2.resize(binary_mask, original_size) # Resize mask to original
31
- binary_mask_rgb = cv2.cvtColor(binary_mask_resized, cv2.COLOR_GRAY2RGB) # Convert to 3-channel RGB
32
- return binary_mask_rgb
33
-
34
- # Gradio Interface
35
- interface = gr.Interface(
36
- fn=segment_image,
37
- inputs=gr.Image(type="numpy", label="Upload Image"),
38
- outputs=gr.Image(type="numpy", label="Segmented Mask"),
39
- title="Image Segmentation with U-Net",
40
- description="Upload an image to see the segmentation mask predicted by the U-Net model."
41
- )
42
-
43
- if __name__ == "__main__":
44
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
  import numpy as np
3
  import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
+ import matplotlib.pyplot as plt
6
+ import os
7
 
8
+ # Define image dimensions
9
+ IMG_HEIGHT = 256
10
+ IMG_WIDTH = 256
11
+
12
+ # Load trained model
13
  def iou_metric(y_true, y_pred):
14
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
15
  intersection = tf.reduce_sum(y_true * y_pred)
16
  union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
17
  return intersection / (union + 1e-7)
18
 
19
+ def dice_metric(y_true, y_pred):
20
+ y_pred = tf.cast(y_pred > 0.5, tf.float32)
21
+ intersection = tf.reduce_sum(y_true * y_pred)
22
+ return (2. * intersection) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1e-7)
23
+
24
+ model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric})
25
+
26
+ # Function to preprocess input image
27
+ def preprocess_image(image_path):
28
+ img = cv2.imread(image_path)
29
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
30
+ img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT)) / 255.0 # Resize and normalize
31
+ return np.expand_dims(img, axis=0) # Add batch dimension
32
+
33
+ # Function to make predictions
34
+ def predict_mask(image_path):
35
+ img = preprocess_image(image_path)
36
+ mask_pred = model.predict(img)[0] # Remove batch dimension
37
+ mask_pred = (mask_pred > 0.5).astype(np.uint8) # Threshold mask
38
+ return mask_pred
39
+
40
+ # Function to visualize results
41
+ def visualize_results(image_path, mask_pred):
42
+ img = cv2.imread(image_path)
43
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # Convert to RGB
44
+ img_resized = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
45
+
46
+ fig, ax = plt.subplots(1, 2, figsize=(12, 4))
47
+ ax[0].imshow(img_resized)
48
+ ax[0].set_title("Original Image")
49
+ ax[0].axis("off")
50
+
51
+ ax[1].imshow(mask_pred, cmap="gray")
52
+ ax[1].set_title("Predicted Mask")
53
+ ax[1].axis("off")
54
+
55
+ # overlay = img_resized.copy()
56
+ # overlay[mask_pred == 1] = [255, 0, 0] # Apply red overlay on mask
57
+ # ax[2].imshow(overlay)
58
+ # ax[2].set_title("Overlay")
59
+ # ax[2].axis("off")
60
+
61
+ plt.show()
62
+
63
+ # Test on an image
64
+ TEST_IMAGE = "/content/MSFD/1/face_crop/000020_1.jpg" # Change this to your test image path
65
+ if os.path.exists(TEST_IMAGE):
66
+ mask_prediction = predict_mask(TEST_IMAGE)
67
+ visualize_results(TEST_IMAGE, mask_prediction)
68
+ else:
69
+ print(f"Test image '{TEST_IMAGE}' not found. Please check the path.")
70
+
71
+ isse bana de app.py