Spaces:
Build error
Build error
File size: 3,213 Bytes
9da561f c705fd9 948a954 a268cd1 c705fd9 9da561f c705fd9 9da561f c705fd9 9da561f c705fd9 9da561f c705fd9 9da561f f33a71a 9da561f c705fd9 9da561f c705fd9 9da561f a268cd1 971d962 c705fd9 9da561f 971d962 3cddee9 e273c8a c705fd9 971d962 a268cd1 0e9696b 971d962 3373fee 971d962 3373fee 9da561f 971d962 39a4a2b 9da561f 971d962 39a4a2b af5a5ef 319b502 a38e328 8f5860b 9da561f 8f5860b 9da561f a38e328 9da561f 319b502 9da561f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import os

# Suppress TensorFlow info/warning logs.
# TF_CPP_MIN_LOG_LEVEL is consumed by TensorFlow's native logging layer, which
# is normally initialised while the package is being imported. The variable
# therefore has to be in the environment BEFORE `import tensorflow`; setting
# it afterwards (as this file used to do) leaves the log output unchanged.
# Level '2' filters out INFO and WARNING messages, keeping errors.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr
import numpy as np
import tensorflow as tf
from skimage.io import imsave
from skimage.transform import resize
import matplotlib.pyplot as plt
# Per-image standardization
def standardize(img):
    """Shift an image to zero mean and scale it by an adjusted standard deviation.

    The divisor is ``max(std(img), 1 / sqrt(N))`` with ``N`` the product of
    the first two axis lengths, so a constant image does not trigger a
    division by zero.

    Args:
        img: Array-like image with at least two axes.

    Returns:
        The standardized array. A 2-D input is stacked three times along a
        new last axis; any other rank keeps its shape.
    """
    rows, cols = np.shape(img)[0], np.shape(img)[1]
    # Lower bound on the divisor, derived from the pixel count.
    floor = 1.0 / np.sqrt(rows * cols)
    spread = np.maximum(np.std(img), floor)
    scaled = (img - np.mean(img)) / spread
    if np.ndim(scaled) != 2:
        return scaled
    # Single-channel image: replicate it into three identical channels.
    return np.dstack((scaled, scaled, scaled))
# Load the exported model once, at import time, so every request reuses it.
filepath = './saved_model'
# TFSMLayer wraps the SavedModel directory as a callable layer; calls are
# routed to the signature named by `call_endpoint`.
model = tf.keras.layers.TFSMLayer(
    filepath,
    call_endpoint='serving_default',
)
# Segmentation function
def FRFsegment(input_img):
    """Segment one image with the loaded model and render an overlay.

    The image is standardized, resized to 512x512 and passed to ``model``.
    The prediction is resized back to the input's height and width, written
    to ``label.png`` as an 8-bit image, and drawn semi-transparently over the
    input in ``overlay.png``.

    Args:
        input_img: Image as a NumPy array whose first two axes are rows and
            columns (the Gradio input is declared with ``type="numpy"``).

    Returns:
        A tuple ``(overlay_img, "label.png", "overlay.png")`` where
        ``overlay_img`` is the saved overlay read back as an array and the
        two strings are the paths of the files written by this call.

    Raises:
        gr.Error: If no image was supplied (``input_img`` is ``None``).
    """
    if input_img is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an image to segment.")

    dims = (512, 512)
    # ndarray shapes are (rows, cols, ...): height comes first.
    height, width = input_img.shape[:2]

    # Standardize and resize the input image, then add a batch axis.
    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)

    # Model prediction
    est_label_dict = model(img)
    # Print available keys to understand what we are dealing with
    print("Available keys in the output dictionary:", est_label_dict.keys())
    # The output name is not known in advance; use the first available key.
    key = next(iter(est_label_dict.keys()))
    est_label = est_label_dict[key]

    if len(est_label.shape) == 4 and est_label.shape[-1] > 1:
        # Multi-class output: collapse the class axis to integer class ids.
        mask = np.argmax(np.squeeze(est_label, axis=0), -1)
        # Class ids are categorical. Interpolating between them would invent
        # classes along region borders (e.g. 0 and 3 blending into 1 or 2),
        # so use nearest-neighbour resampling without smoothing.
        resize_options = {"order": 0, "anti_aliasing": False}
    else:
        # Binary segmentation or unexpected shape: keep the raw values.
        mask = np.squeeze(est_label, axis=0)
        if np.ndim(mask) == 3 and np.shape(mask)[-1] == 1:
            # Drop the singleton channel axis so the mask is a plain 2-D
            # image; imshow does not accept an (H, W, 1) array.
            mask = mask[..., 0]
        resize_options = {}

    # Resize the mask back to the original input dimensions.
    pred = resize(
        mask, (height, width), preserve_range=True, clip=True, **resize_options
    )

    # Stretch to 0..255. A mask whose maximum is not positive (for example
    # every pixel predicted as class 0) would divide by zero, so it is
    # written out as all zeros instead.
    peak = np.max(pred)
    if peak > 0:
        pred_uint8 = (pred / peak * 255).astype(np.uint8)
    else:
        pred_uint8 = np.zeros(np.shape(pred), dtype=np.uint8)

    # Save predicted mask
    imsave("label.png", pred_uint8)

    # Draw on a dedicated figure rather than pyplot's shared "current"
    # figure, and always close it so figures do not accumulate across calls.
    fig, ax = plt.subplots()
    try:
        ax.imshow(input_img, cmap='gray')
        # Use 'jet' colormap to enhance visibility
        ax.imshow(pred, cmap='jet', alpha=0.4)
        ax.axis("off")
        ax.margins(x=0, y=0)
        fig.savefig("overlay.png", dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)

    # Read the overlay image back so it can be returned as an array.
    overlay_img = plt.imread("overlay.png")
    return overlay_img, "label.png", "overlay.png"
# Absolute paths of the bundled example images, resolved against the
# current working directory.
example_dir = os.path.join(os.getcwd(), "examples")
example_images = [
    os.path.join(example_dir, filename)
    for filename in (
        "FRF_c1_snap_20191112160000.jpg",
        "FRF_c1_snap_20170101.jpg",
    )
]
# Gradio user interface: one image in; overlay preview plus two
# downloadable files out.
title = "Segment beach imagery taken from a tower in Duck, NC, USA"
description = "This model segments beach imagery into 4 classes: vegetation, sand, coarse sand, and background (water + sky + buildings + people)"

# The input is handed to FRFsegment as a NumPy array.
input_component = gr.Image(type="numpy")
# Order matches the tuple returned by FRFsegment.
output_components = [
    gr.Image(label="Overlay Image"),
    gr.File(label="Segmentation Mask Download"),
    gr.File(label="Overlay Image Download"),
]

FRFSegapp = gr.Interface(
    fn=FRFsegment,
    inputs=input_component,
    outputs=output_components,
    examples=example_images,
    title=title,
    description=description,
)

# Start the web server when the module is executed.
FRFSegapp.launch()
|