# FRF_Heavies / app.py
# Last commit 8f400fc by ebgoldstein: "Update app.py"
import gradio as gr
import numpy as np
import tensorflow as tf
from skimage.io import imsave
from skimage.transform import resize
import matplotlib.pyplot as plt
#from SegZoo
def standardize(img):
    """Return *img* shifted to zero mean and scaled by an adjusted std.

    The divisor is the global standard deviation, floored at
    1 / sqrt(rows * cols) so that near-constant images are not divided by
    (almost) zero. If the result is 2-D, it is stacked three times along a
    new last axis to give a three-channel array.
    """
    shape = np.shape(img)
    n_pixels = shape[0] * shape[1]
    # Floor the spread so a flat image does not explode.
    scale = np.maximum(np.std(img), 1.0 / np.sqrt(n_pixels))
    centred = (img - np.mean(img)) / scale
    if np.ndim(centred) != 2:
        return centred
    # Single-channel input: replicate into three identical channels.
    return np.dstack([centred] * 3)
#load model
# Model directory, resolved relative to the current working directory
# (presumably a TensorFlow SavedModel export — confirm).
filepath = './saved_model'
# compile=True asks Keras to compile the model using the training config saved
# with it, so no separate compile call is needed here.
model = tf.keras.models.load_model(filepath, compile = True)
#segmentation
def FRFsegment(input_img):
    """Segment one beach image and save the label mask and an overlay figure.

    The image is standardized, resized to the 512x512 model input, passed
    through the module-level ``model``, and the per-pixel argmax class map
    is resized back to the input's (rows, cols).

    Args:
        input_img: image array from the Gradio ``Image`` input; presumably
            (rows, cols, 3) uint8 RGB — TODO confirm.

    Returns:
        A 3-tuple:
        - the ``matplotlib.pyplot`` module, whose current figure holds the
          overlay (consumed by the 'plot' output);
        - the path "label.png";
        - the path "overlay.png".
        Both files are written to the current working directory.
    """
    dims = (512, 512)
    # Original (rows, cols); skimage's resize takes output_shape in this order.
    n_rows, n_cols = input_img.shape[0], input_img.shape[1]

    img = standardize(input_img)
    img = resize(img, dims, preserve_range=True, clip=True)
    img = np.expand_dims(img, axis=0)  # add a batch axis of size 1
    est_label = model.predict(img)
    # NOTE: a flip-based test-time augmentation was previously sketched here.
    # If reinstated, flip the spatial axes (1 and 2) of the batched array;
    # axis 0 is the size-1 batch axis, so np.flipud on it is a no-op.

    # Class index per pixel; assumes the output is (1, H, W, n_classes).
    mask = np.argmax(np.squeeze(est_label, axis=0), -1)
    # Labels are categorical, so map them back with nearest-neighbour
    # interpolation and no smoothing. Bilinear interpolation or anti-aliasing
    # would blend neighbouring ids into values that are not real classes.
    pred = resize(
        mask,
        (n_rows, n_cols),
        order=0,
        preserve_range=True,
        anti_aliasing=False,
    ).astype(np.uint8)
    # uint8 keeps the raw class ids in the PNG instead of a rescaled float image.
    imsave("label.png", pred)

    #overlay plot
    plt.clf()
    plt.imshow(input_img, cmap='gray')
    plt.imshow(pred, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig("overlay.png", dpi=300, bbox_inches="tight")
    return plt, "label.png", "overlay.png"
title = "Segment beach imagery taken from a tower in Duck, NC, USA"
description = "This model segments beach imagery into 4 classes: vegetation, sand, heavy minerals, and background (water + sky + buildings + people)"
examples = [['examples/FRF_c1_snap_20191112160000.jpg'], ['examples/FRF_c1_snap_20170401.jpg']]

# One image in. Out: the overlay plot plus downloadable label and overlay files.
# NOTE(review): gr.inputs / gr.outputs and string themes are the legacy Gradio
# API (deprecated in Gradio 3, removed in Gradio 4) — pin the Gradio version
# or migrate to gr.Image / gr.Plot / gr.File.
FRFSegapp = gr.Interface(
    FRFsegment,
    gr.inputs.Image(),
    ['plot', gr.outputs.File(), gr.outputs.File()],
    examples=examples,
    title=title,
    description=description,
    theme="grass",
)
# Keep FRFSegapp bound to the Interface itself rather than launch()'s return value.
FRFSegapp.launch()