aryamanpathak commited on
Commit
c89c5ca
·
verified ·
1 Parent(s): f37c0a1

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -0
main.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ import cv2
6
+
7
+ # Custom metric (if any, otherwise comment this out)
8
+ def iou_metric(y_true, y_pred):
9
+ y_true = tf.cast(y_true > 0.5, tf.float32)
10
+ y_pred = tf.cast(y_pred > 0.5, tf.float32)
11
+ intersection = tf.reduce_sum(y_true * y_pred)
12
+ union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
13
+ return intersection / (union + 1e-7)
14
+
15
+ # Load the model
16
+ model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric})
17
+
18
+ # Preprocess image
19
+ def preprocess_image(image, target_size=(256, 256)):
20
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
21
+ image = cv2.resize(image, target_size)
22
+ image = image / 255.0
23
+ return np.expand_dims(image, axis=0)
24
+
25
+ # Predict mask
26
+ def segment_image(input_image):
27
+ preprocessed = preprocess_image(input_image)
28
+ pred_mask = model.predict(preprocessed)[0]
29
+ binary_mask = (pred_mask > 0.5).astype(np.uint8) * 255
30
+ binary_mask = cv2.resize(binary_mask, (input_image.shape[1], input_image.shape[0]))
31
+ return binary_mask
32
+
33
+ # Gradio interface
34
+ interface = gr.Interface(
35
+ fn=segment_image,
36
+ inputs=gr.Image(type="numpy", label="Upload Image"),
37
+ outputs=gr.Image(type="numpy", label="Segmented Mask"),
38
+ title="Image Segmentation with U-Net",
39
+ description="Upload an image to see the segmentation mask predicted by the U-Net model."
40
+ )
41
+
42
+ if __name__ == "__main__":
43
+ interface.launch()