aryamanpathak commited on
Commit
aaa7a88
·
verified ·
1 Parent(s): 1177d0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -41
app.py CHANGED
@@ -1,59 +1,55 @@
1
  import gradio as gr
2
- import cv2
3
  import numpy as np
4
  import tensorflow as tf
5
- from tensorflow.keras.models import load_model
6
 
7
- # Constants
8
- IMG_HEIGHT, IMG_WIDTH = 256, 256
9
 
10
- # Custom metrics
11
- def iou_metric(y_true, y_pred):
12
- y_pred = tf.cast(y_pred > 0.5, tf.float32)
13
- intersection = tf.reduce_sum(y_true * y_pred)
14
- union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
15
- return intersection / (union + 1e-7)
16
 
17
- def dice_metric(y_true, y_pred):
18
- y_pred = tf.cast(y_pred > 0.5, tf.float32)
19
- intersection = tf.reduce_sum(y_true * y_pred)
20
- return (2. * intersection) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1e-7)
21
 
22
- # Load model
23
- model = load_model("unet_mask_segmentation.h5", custom_objects={'iou_metric': iou_metric, 'dice_metric': dice_metric})
 
24
 
25
- # Predict mask function
26
- def predict(image):
27
- # Convert PIL to NumPy array
28
- img_np = np.array(image)
29
- img_resized = cv2.resize(img_np, (IMG_WIDTH, IMG_HEIGHT)) / 255.0
30
- img_input = np.expand_dims(img_resized, axis=0)
31
 
32
- # Predict mask
33
- mask_pred = model.predict(img_input)[0]
34
- mask_binary = (mask_pred > 0.5).astype(np.uint8) * 255
35
 
36
  # Create overlay
37
- overlay = img_resized.copy()
38
- overlay[mask_binary == 255] = [1.0, 0, 0] # Red overlay in normalized scale
 
 
 
 
39
 
40
- # Convert overlay and mask to displayable format
41
- overlay_display = (overlay * 255).astype(np.uint8)
42
- mask_display = mask_binary.astype(np.uint8)
43
 
44
- return overlay_display, mask_display
45
 
46
- # Gradio Interface
47
  interface = gr.Interface(
48
  fn=predict,
49
- inputs=gr.Image(type="pil"),
50
- outputs=[
51
- gr.Image(type="numpy", label="Overlay"),
52
- gr.Image(type="numpy", label="Predicted Mask")
53
- ],
54
- title="Face Mask Segmentation",
55
- description="Upload a face image to get the predicted segmentation mask using U-Net."
56
  )
57
 
58
- if __name__ == "__main__":
59
- interface.launch()
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import tensorflow as tf
4
+ import cv2
5
 
6
+ # Load your trained Keras model
7
+ model = tf.keras.models.load_model("model.h5")
8
 
9
+ # Image preprocessing function (same as used during training)
10
+ def preprocess_image(img):
11
+ img_resized = cv2.resize(img, (256, 256))
12
+ img_normalized = img_resized / 255.0 # Normalize to 0-1
13
+ return img_normalized
 
14
 
15
+ # Prediction and overlay function
16
+ def predict(input_img):
17
+ # Ensure image is RGB and numpy array
18
+ img = np.array(input_img.convert("RGB"))
19
 
20
+ # Preprocess
21
+ preprocessed_img = preprocess_image(img)
22
+ input_tensor = np.expand_dims(preprocessed_img, axis=0) # Add batch dimension
23
 
24
+ # Model prediction
25
+ prediction = model.predict(input_tensor)[0] # Remove batch dim
 
 
 
 
26
 
27
+ # Post-processing mask
28
+ mask = (prediction > 0.5).astype(np.uint8) # Binary mask
29
+ mask_resized = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
30
 
31
  # Create overlay
32
+ overlay = img.astype(np.float32) / 255.0 # Normalize input image
33
+ alpha = 0.5 # Transparency of overlay
34
+
35
+ # Create red mask in RGB format
36
+ red_mask = np.zeros_like(overlay)
37
+ red_mask[:, :, 0] = mask_resized # Red channel
38
 
39
+ # Alpha blend original image with red mask
40
+ blended = (1 - alpha) * overlay + alpha * red_mask
41
+ blended = np.clip(blended * 255, 0, 255).astype(np.uint8)
42
 
43
+ return blended
44
 
45
+ # Gradio interface
46
  interface = gr.Interface(
47
  fn=predict,
48
+ inputs=gr.Image(type="pil", label="Upload Image"),
49
+ outputs=gr.Image(type="numpy", label="Segmented Image"),
50
+ title="Image Segmentation App",
51
+ description="Upload an image and get the segmentation mask overlay using your trained model."
 
 
 
52
  )
53
 
54
+ # Launch Gradio app (enable public link for Hugging Face Spaces)
55
+ interface.launch(share=True)