Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import tensorflow as tf | |
| from skimage.io import imsave | |
| from skimage.transform import resize | |
| import matplotlib.pyplot as plt | |
| #from SegZoo | |
def standardize(img):
    """Rescale an image to zero mean and unit adjusted standard deviation.

    The divisor is the larger of the image's standard deviation and
    ``1 / sqrt(rows * cols)``, which keeps a constant image from dividing
    by zero. A 2-D result is stacked three times along a new last axis;
    results with any other number of dimensions are returned as they are.
    """
    dims = np.shape(img)
    pixel_count = dims[0] * dims[1]
    # Lower bound on the divisor, derived from the number of pixels.
    divisor = np.maximum(np.std(img), 1.0 / np.sqrt(pixel_count))
    scaled = (img - np.mean(img)) / divisor
    if np.ndim(scaled) != 2:
        return scaled
    # Single-band input: replicate it into three identical channels.
    return np.dstack((scaled, scaled, scaled))
# Load the saved model once at import time so every request reuses it.
# load_model(..., compile=True) already asks Keras to restore the saved
# compile state, so no separate compile step is needed here; the visible
# code only ever calls model.predict().
filepath = './saved_model'
model = tf.keras.models.load_model(filepath, compile=True)
| #segmentation | |
def FRFsegment(input_img):
    """Segment one image with the loaded model and write the result files.

    Args:
        input_img: Image array whose first two axes are rows and columns
            (presumably the array Gradio hands over from the image input
            -- confirm against the interface definition).

    Returns:
        A 3-tuple ``(plt, "label.png", "overlay.png")``: the pyplot module
        holding the overlay figure, the path of the saved label image and
        the path of the saved overlay image.

    Side effects:
        Overwrites ``label.png`` and ``overlay.png`` in the working
        directory and redraws the current pyplot figure.
    """
    # Size the image is resized to before prediction (presumably the
    # model's expected input size -- confirm against the saved model).
    model_size = (512, 512)
    n_rows, n_cols = input_img.shape[0], input_img.shape[1]

    img = standardize(input_img)
    img = resize(img, model_size, preserve_range=True, clip=True)
    batch = np.expand_dims(img, axis=0)
    est_label = model.predict(batch)

    # Per-pixel class index: argmax over the last axis of the prediction.
    mask = np.argmax(np.squeeze(est_label, axis=0), -1)

    # Class indices are categorical, so scale back to the input size with
    # nearest-neighbour sampling (order=0) and no smoothing; interpolating
    # would create fractional, non-existent classes along region borders.
    pred = resize(
        mask,
        (n_rows, n_cols),
        order=0,
        anti_aliasing=False,
        preserve_range=True,
        clip=True,
    ).astype(np.uint8)

    # Store the raw class indices as 8-bit integers. Writing a float array
    # forces a lossy conversion in the image writer. check_contrast=False
    # silences the low-contrast warning that small index values trigger.
    imsave("label.png", pred, check_contrast=False)

    # Overlay plot: the input image with the semi-transparent label on top.
    plt.clf()
    plt.imshow(input_img, cmap='gray')
    plt.imshow(pred, alpha=0.4)
    plt.axis("off")
    plt.margins(x=0, y=0)
    plt.savefig("overlay.png", dpi=300, bbox_inches="tight")
    return plt, "label.png", "overlay.png"
# Text shown in the web interface.
title = "Segment beach imagery taken from a tower in Duck, NC, USA"
description = "This model segments beach imagery into 4 classes: vegetation, sand, heavy minerals, and background (water + sky + buildings + people)"
# Example inputs offered in the interface (one image path per example row).
examples = [['examples/FRF_c1_snap_20191112160000.jpg'], ['examples/FRF_c1_snap_20170401.jpg']]

# Build the interface first and launch it afterwards, so that FRFSegapp
# refers to the Interface object itself rather than to whatever launch()
# returns. The three outputs match FRFsegment's return tuple: the overlay
# plot, the label file and the overlay file.
FRFSegapp = gr.Interface(
    FRFsegment,
    gr.inputs.Image(),
    ['plot', gr.outputs.File(), gr.outputs.File()],
    examples=examples,
    title=title,
    description=description,
    theme="grass",
)
FRFSegapp.launch()