File size: 3,213 Bytes
9da561f
c705fd9
 
 
948a954
a268cd1
c705fd9
 
9da561f
 
c705fd9
9da561f
 
 
c705fd9
9da561f
c705fd9
 
9da561f
 
c705fd9
 
9da561f
f33a71a
9da561f
c705fd9
9da561f
c705fd9
9da561f
 
a268cd1
971d962
c705fd9
9da561f
 
 
971d962
3cddee9
 
e273c8a
 
 
 
 
 
c705fd9
971d962
 
 
 
 
 
 
 
 
a268cd1
0e9696b
971d962
3373fee
 
971d962
3373fee
9da561f
971d962
39a4a2b
9da561f
971d962
39a4a2b
 
af5a5ef
319b502
a38e328
 
 
 
 
 
 
 
 
 
 
8f5860b
9da561f
8f5860b
 
9da561f
 
a38e328
 
 
 
 
9da561f
 
319b502
9da561f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

# Suppress TensorFlow warnings.
# This has to be in the environment *before* `import tensorflow`: the native
# TensorFlow library starts logging while it is being loaded, and anything it
# prints before the variable exists is not filtered. Level '2' hides INFO and
# WARNING messages and keeps errors visible.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# The remaining imports intentionally follow the assignment above.
import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from skimage.io import imsave  # noqa: E402
from skimage.transform import resize  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402

# Standardize function
def standardize(img):
    """Return *img* shifted to zero mean and scaled by an adjusted std.

    The divisor is the larger of the image's standard deviation and
    ``1 / sqrt(rows * cols)``, so a constant image is divided by a small
    positive number instead of by zero. A 2-D input is replicated into three
    identical channels along a new last axis; inputs of any other rank keep
    their shape.
    """
    shape = np.shape(img)
    pixel_count = shape[0] * shape[1]

    # Lower bound on the divisor, derived from the first two dimensions only.
    std_floor = 1.0 / np.sqrt(pixel_count)
    adjusted_std = np.maximum(np.std(img), std_floor)

    scaled = (img - np.mean(img)) / adjusted_std

    if np.ndim(scaled) != 2:
        return scaled
    # Greyscale input: stack three copies depth-wise to get (rows, cols, 3).
    return np.dstack((scaled,) * 3)

# Load model
# The SavedModel directory at `filepath` is wrapped as a Keras layer that
# invokes the model's 'serving_default' endpoint. This runs once, at import
# time; FRFsegment() calls `model` and reads its result as a dict of outputs.
filepath = './saved_model'
model = tf.keras.layers.TFSMLayer(filepath, call_endpoint='serving_default')

# Segmentation function
def FRFsegment(input_img):
    """Segment an image with the loaded model and render the result.

    Parameters
    ----------
    input_img : numpy.ndarray
        Image array; its first two dimensions are taken as (rows, columns).

    Returns
    -------
    tuple
        ``(overlay_img, "label.png", "overlay.png")``: the rendered overlay
        read back from disk as an array, followed by the paths of the saved
        mask and overlay files.

    Notes
    -----
    Both files are written to the current working directory under fixed
    names, so every call overwrites the output of the previous one.
    """
    dims = (512, 512)
    # ndarray.shape is (rows, cols, ...), i.e. height first, then width.
    n_rows, n_cols = input_img.shape[:2]

    # Standardize and resize the input image, then add a leading batch axis
    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)

    # Model prediction
    est_label_dict = model(img)

    # Print available keys to understand what we are dealing with
    print("Available keys in the output dictionary:", est_label_dict.keys())

    # Extract the actual tensor from the dictionary using a dynamic key lookup
    key = list(est_label_dict.keys())[0]  # Use the first available key
    est_label = est_label_dict[key]

    # Check the shape of the predicted output
    if len(est_label.shape) == 4 and est_label.shape[-1] > 1:
        # Multi-class segmentation: apply argmax to get one class id per pixel
        mask = np.argmax(np.squeeze(est_label, axis=0), -1)
        # Class ids are categorical. Interpolating between them would invent
        # ids that the model never predicted along class boundaries, so use
        # nearest-neighbour resampling without smoothing.
        resize_options = {"order": 0, "anti_aliasing": False}
    else:
        # Binary segmentation or unexpected shape: keep the default resampling
        mask = np.squeeze(est_label, axis=0)
        resize_options = {}

    # Resize the mask back to original input dimensions
    pred = resize(
        mask, (n_rows, n_cols), preserve_range=True, clip=True, **resize_options
    )

    # Convert prediction to uint8 format, stretching the largest value to 255.
    # When nothing is above zero (e.g. every pixel is class 0) there is nothing
    # to stretch, and dividing by the maximum would produce NaN.
    peak = np.max(pred)
    if peak > 0:
        pred_uint8 = (pred / peak * 255).astype(np.uint8)
    else:
        pred_uint8 = np.zeros(pred.shape, dtype=np.uint8)

    # Save predicted mask
    imsave("label.png", pred_uint8)

    # Overlay the segmentation on the original input image
    plt.clf()
    plt.imshow(input_img, cmap='gray')
    plt.imshow(pred, cmap='jet', alpha=0.4)  # Use 'jet' colormap to enhance visibility
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig("overlay.png", dpi=300, bbox_inches="tight")

    # Read the overlay image to return it as an output
    overlay_img = plt.imread("overlay.png")

    return overlay_img, "label.png", "overlay.png"

# Prepare absolute paths for example images
# The examples live in an "examples" folder under the current working
# directory; each file name below is joined onto that folder.
example_dir = os.path.join(os.getcwd(), "examples")
example_images = [
    os.path.join(example_dir, filename)
    for filename in (
        "FRF_c1_snap_20191112160000.jpg",
        "FRF_c1_snap_20170101.jpg",
    )
]

# Gradio Interface
title = "Segment beach imagery taken from a tower in Duck, NC, USA"
description = (
    "This model segments beach imagery into 4 classes: "
    "vegetation, sand, coarse sand, and background "
    "(water + sky + buildings + people)"
)

# The three outputs are listed in the same order as the tuple returned by
# FRFsegment: overlay array, mask file path, overlay file path.
FRFSegapp = gr.Interface(
    fn=FRFsegment,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(label="Overlay Image"),
        gr.File(label="Segmentation Mask Download"),
        gr.File(label="Overlay Image Download"),
    ],
    examples=example_images,
    title=title,
    description=description,
)

# Start the web app as soon as this file is executed.
FRFSegapp.launch()