"""Gradio app that segments beach imagery taken from a tower in Duck, NC, USA.

A TensorFlow SavedModel is loaded from ``./saved_model`` at import time.
``FRFsegment`` runs it on an uploaded image, writes ``label.png`` and
``overlay.png`` to the current working directory and returns the overlay.
"""

import os

# Suppress TensorFlow warnings. The variable is set *before* TensorFlow is
# imported so that it also covers the messages logged during the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import gradio as gr  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from skimage.io import imsave  # noqa: E402
from skimage.transform import resize  # noqa: E402

# Spatial size every image is resized to before it is passed to the model.
_MODEL_INPUT_DIMS = (512, 512)


def standardize(img):
    """Return ``img`` shifted to zero mean and scaled by an adjusted std.

    The standard deviation is floored at ``1 / sqrt(N)``, where ``N`` is the
    number of pixels (rows * columns), so a constant image does not cause a
    division by zero. A 2-D input is stacked three times along a new last
    axis so that the result always has a channel dimension.
    """
    n_pixels = np.shape(img)[0] * np.shape(img)[1]
    adjusted_std = np.maximum(np.std(img), 1.0 / np.sqrt(n_pixels))
    img = (img - np.mean(img)) / adjusted_std
    if np.ndim(img) == 2:
        img = np.dstack((img, img, img))
    return img


# Load model
filepath = './saved_model'
model = tf.keras.layers.TFSMLayer(filepath, call_endpoint='serving_default')


def _extract_mask(est_label):
    """Collapse a raw model output into a mask.

    Returns ``(mask, is_label_map)``. ``is_label_map`` is True when the mask
    holds integer class indices obtained with argmax (4-D output with more
    than one channel) and False when the output was only squeezed.
    """
    est_label = np.asarray(est_label)
    if est_label.ndim == 4 and est_label.shape[-1] > 1:
        # Multi-class segmentation: pick the most likely class per pixel.
        return np.argmax(np.squeeze(est_label, axis=0), -1), True

    # Binary segmentation or unexpected shape.
    mask = np.squeeze(est_label, axis=0)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        # Drop the singleton channel; an (H, W, 1) array cannot be drawn
        # with imshow further down.
        mask = mask[..., 0]
    return mask, False


def _scale_to_uint8(pred):
    """Scale ``pred`` so that its maximum maps to 255 and cast to uint8.

    An input whose maximum is not positive (for example an all-zero mask)
    yields an all-zero image instead of dividing by zero.
    """
    # NOTE(review): scaling by the per-image maximum means the grey level of
    # a given class depends on which classes are present - confirm that this
    # is what consumers of label.png expect.
    peak = np.max(pred)
    if peak > 0:
        return (pred / peak * 255).astype(np.uint8)
    return np.zeros(np.shape(pred), dtype=np.uint8)


def _save_overlay(input_img, pred, path):
    """Draw ``pred`` semi-transparently over ``input_img`` and save to ``path``."""
    # An explicit figure avoids depending on (and mutating) pyplot's implicit
    # "current figure"; it is always closed so figures do not pile up.
    fig, ax = plt.subplots()
    try:
        ax.imshow(input_img, cmap='gray')
        # Use 'jet' colormap to enhance visibility
        ax.imshow(pred, cmap='jet', alpha=0.4)
        ax.axis("off")
        ax.margins(x=0, y=0)
        fig.savefig(path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)


def FRFsegment(input_img):
    """Segment ``input_img`` and return the overlay plus the saved file paths.

    Args:
        input_img: Image as a NumPy array whose first two axes are
            (rows, columns), as delivered by ``gr.Image(type="numpy")``.

    Returns:
        A tuple ``(overlay_img, "label.png", "overlay.png")`` where
        ``overlay_img`` is the saved overlay read back with ``plt.imread``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_img is None:
        raise gr.Error("Please provide an image to segment.")

    # shape[:2] is (rows, columns); the mask is resized back to this later.
    n_rows, n_cols = input_img.shape[:2]

    # Standardize and resize the input image, then add the batch axis.
    img = standardize(input_img)
    img = resize(img, _MODEL_INPUT_DIMS, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)

    # Model prediction
    est_label_dict = model(img)

    # Print available keys to understand what we are dealing with
    print("Available keys in the output dictionary:", est_label_dict.keys())

    # Use the first available key rather than hard-coding the output name.
    key = next(iter(est_label_dict))
    mask, is_label_map = _extract_mask(est_label_dict[key])

    # Resize the mask back to the original input dimensions.
    if is_label_map:
        # Nearest-neighbour keeps the values valid class indices; linear
        # interpolation would invent in-between "classes" at boundaries.
        pred = resize(
            mask,
            (n_rows, n_cols),
            order=0,
            preserve_range=True,
            clip=True,
            anti_aliasing=False,
        )
    else:
        pred = resize(mask, (n_rows, n_cols), preserve_range=True, clip=True)

    # Save predicted mask
    imsave("label.png", _scale_to_uint8(pred))

    # Overlay the segmentation on the original input image
    _save_overlay(input_img, pred, "overlay.png")

    # Read the overlay image to return it as an output
    overlay_img = plt.imread("overlay.png")

    return overlay_img, "label.png", "overlay.png"


# Prepare absolute paths for example images
example_dir = os.path.join(os.getcwd(), "examples")
example_images = [
    os.path.join(example_dir, "FRF_c1_snap_20191112160000.jpg"),
    os.path.join(example_dir, "FRF_c1_snap_20170101.jpg"),
]

# Gradio Interface
title = "Segment beach imagery taken from a tower in Duck, NC, USA"
description = "This model segments beach imagery into 4 classes: vegetation, sand, coarse sand, and background (water + sky + buildings + people)"

FRFSegapp = gr.Interface(
    fn=FRFsegment,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Image(label="Overlay Image"),
        gr.File(label="Segmentation Mask Download"),
        gr.File(label="Overlay Image Download"),
    ],
    examples=example_images,
    title=title,
    description=description,
)

# Only start the server when run as a script, so importing this module
# (e.g. to reuse FRFsegment) does not block on a running web server.
if __name__ == "__main__":
    FRFSegapp.launch()