aryamanpathak commited on
Commit
1690b06
·
verified ·
1 Parent(s): 34e632b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -4,33 +4,34 @@ import tensorflow as tf
4
  from tensorflow.keras.models import load_model
5
  import cv2
6
 
7
- # Custom metric (if any, otherwise comment this out)
8
  def iou_metric(y_true, y_pred):
9
- y_true = tf.cast(y_true > 0.5, tf.float32)
10
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
11
  intersection = tf.reduce_sum(y_true * y_pred)
12
  union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
13
  return intersection / (union + 1e-7)
14
 
15
- # Load the model
16
  model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric})
17
 
18
- # Preprocess image
19
  def preprocess_image(image, target_size=(256, 256)):
20
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
21
- image = cv2.resize(image, target_size)
22
- image = image / 255.0
23
- return np.expand_dims(image, axis=0)
24
 
25
- # Predict mask
26
  def segment_image(input_image):
 
27
  preprocessed = preprocess_image(input_image)
28
- pred_mask = model.predict(preprocessed)[0]
29
  binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
30
- binary_mask = cv2.resize(binary_mask, (input_image.shape[1], input_image.shape[0]))
31
- return binary_mask
 
32
 
33
- # Gradio interface
34
  interface = gr.Interface(
35
  fn=segment_image,
36
  inputs=gr.Image(type="numpy", label="Upload Image"),
 
4
  from tensorflow.keras.models import load_model
5
  import cv2
6
 
7
+ # Define the IOU metric exactly as provided
8
  def iou_metric(y_true, y_pred):
 
9
  y_pred = tf.cast(y_pred > 0.5, tf.float32)
10
  intersection = tf.reduce_sum(y_true * y_pred)
11
  union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
12
  return intersection / (union + 1e-7)
13
 
14
+ # Load the trained U-Net model
15
  model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric})
16
 
17
+ # Preprocess the uploaded image for prediction
18
  def preprocess_image(image, target_size=(256, 256)):
19
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # Convert RGB -> BGR for OpenCV ops
20
+ image_resized = cv2.resize(image_bgr, target_size)
21
+ image_resized = image_resized / 255.0 # Normalize
22
+ return np.expand_dims(image_resized, axis=0) # Add batch dimension
23
 
24
+ # Predict and return the segmented mask
25
  def segment_image(input_image):
26
+ original_size = (input_image.shape[1], input_image.shape[0]) # (width, height)
27
  preprocessed = preprocess_image(input_image)
28
+ pred_mask = model.predict(preprocessed)[0] # Remove batch dimension
29
  binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
30
+ binary_mask_resized = cv2.resize(binary_mask, original_size) # Resize mask to original
31
+ binary_mask_rgb = cv2.cvtColor(binary_mask_resized, cv2.COLOR_GRAY2RGB) # Convert to 3-channel RGB
32
+ return binary_mask_rgb
33
 
34
+ # Gradio Interface
35
  interface = gr.Interface(
36
  fn=segment_image,
37
  inputs=gr.Image(type="numpy", label="Upload Image"),