# Surgical instrument segmentation demo — by Moez B. (examples, commit 39b09ef)
import gradio as gr
import matplotlib
import matplotlib.cm as cm
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
# Load all models at import time so the first user prediction does not pay the
# download/deserialization cost. hf_hub_download caches locally and returns a
# filesystem path to the .keras file.
unet_path = hf_hub_download(repo_id="treephones/tool-seg-unet", filename="1_simple_unet.keras")
resnet_path = hf_hub_download(repo_id="treephones/tool-seg-resnet50-unet", filename="2_resnet50_unet.keras")
resnet_path_generalized = hf_hub_download(repo_id="treephones/tool-seg-resnet50-generalized", filename="2_resnet50_unet_generalized.keras")
fcn_path = hf_hub_download(repo_id="treephones/tool-seg-fcn-8s", filename="3_fcn_8s.keras")
# compile=False: models are used for inference only, so optimizer/loss state
# is not needed and loading is faster.
unet = tf.keras.models.load_model(unet_path, compile=False)
fcn_8s = tf.keras.models.load_model(fcn_path, compile=False)
rn50_unet = tf.keras.models.load_model(resnet_path, compile=False)
rn50_unet_generalized = tf.keras.models.load_model(resnet_path_generalized, compile=False)
# preprocess: file on disk -> model-ready batch
def preprocess_image(path):
    """Read an image file and return a (1, 256, 256, 3) float32 batch in [0, 1].

    path: filesystem path to the uploaded image (gradio type="filepath").
    Returns a tf.Tensor with a leading batch axis, as the models expect.
    """
    raw = tf.io.read_file(path)
    # decode_image handles PNG, JPEG, BMP and GIF; the previous decode_png
    # rejected the JPEGs browsers commonly upload. expand_animations=False
    # guarantees a rank-3 tensor even for (animated) GIF input.
    img = tf.io.decode_image(raw, channels=3, expand_animations=False)
    img = tf.cast(img, tf.float32) / 255.0
    img = tf.image.resize(img, (256, 256), method=tf.image.ResizeMethod.BILINEAR)
    return tf.expand_dims(img, axis=0)  # add batch dimension
# make output masks look nice
def apply_colormap(mask):
    """Color a 2-D float mask in [0, 1] with the viridis colormap.

    mask: 2-D numpy array of floats in [0, 1].
    Returns an (H, W, 3) uint8 RGB image.
    """
    # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9, which
    # crashes this app on current installs; prefer the colormaps registry and
    # fall back for old matplotlib versions.
    try:
        colormap = matplotlib.colormaps["viridis"]
    except AttributeError:
        colormap = cm.get_cmap("viridis")
    colored = colormap(mask)[:, :, :3]  # drop the alpha channel
    return (colored * 255).astype(np.uint8)
# process output from model for display on gradio
def postprocess_output(output):
    """Convert a raw model prediction into a colored uint8 RGB mask."""
    pred = np.squeeze(output)
    # defensive: strip a leftover trailing singleton channel if one survived
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[:, :, 0]
    return apply_colormap(np.clip(pred, 0, 1))
# predict
def predict_all_models(img):
    """Run every loaded model on one uploaded image.

    img: filepath of the uploaded image (gradio type="filepath").
    Returns the resized original image plus one colored mask per model,
    in the same order as the gradio output widgets.
    """
    batch = preprocess_image(img)
    models = (unet, fcn_8s, rn50_unet, rn50_unet_generalized)
    raw_outputs = [m.predict(batch) for m in models]
    print("Output shapes:", *(out.shape for out in raw_outputs))
    masks = [postprocess_output(out) for out in raw_outputs]
    # recover a displayable uint8 image from the normalized batch
    original = (batch[0].numpy() * 255).astype(np.uint8)
    return (original, *masks)
# gradio stuff: declarative UI layout + event wiring
with gr.Blocks() as demo:
    gr.Markdown("# Generalized Surgical Instrument Segmentation [All Models]")
    gr.Markdown("Upload a single image to see how each segmentation model performs. The first row of output shows the non-generalized basic models. The second row of output shows the better model (U-Net with pretrained ResNet50 encoder) non-generalized and generalized. Generalized = trained on both Laparoscopic Cholecystectomy data and Cataract operation data, Non-generalized = trained only on Cholecystectomy data. This is the free tier so compute power is limited. The only time you may see an error is if multiple people are using it at once. Just refresh and try again if that does happen. By Moez B. and Steven K.")
    gr.Markdown("Test images to copy and paste in if you don't want to find your own:")
    gr.Markdown("## Laparoscopic Cholecystectomy:")
    gr.Markdown("1. https://medtube.net/images/min/bcdfd6b4e85759ae7dacb7778b679aa6/459/345/1")
    gr.Markdown("2. https://www.sages.org/wp-content/uploads/2019/09/fecXdNs6rp0.jpg")
    gr.Markdown("3. https://medtube.net/images/min/6041ca08cd4d11099559472fa4cd992f/459/345/1")
    gr.Markdown("## Cataract Surgery:")
    gr.Markdown("1. https://core4-cms.imgix.net/Screen%20Shot%202021-12-13%20at%203.41.12%20PM_1639428086.png")
    gr.Markdown("2. https://core4-cms.imgix.net/Screen%20Shot%202021-11-30%20at%2010.31.36%20AM_1638286314.png")
    gr.Markdown("3. https://i.ytimg.com/vi/iFQ9-i41Z0I/sddefault.jpg")
    with gr.Row():
        # type="filepath" means predict_all_models receives a path string
        input_image = gr.Image(label="Upload an Image", type="filepath")
    # first output row: resized original + the two basic models
    with gr.Row():
        output1 = gr.Image(label="Original Image (Resized)")
        output2 = gr.Image(label="Simple U-Net")
        output3 = gr.Image(label="FCN-8s")
    # second output row: the stronger ResNet50-encoder variants
    with gr.Row():
        output4 = gr.Image(label="ResNet50 + U-Net [Non-generalized]")
        output5 = gr.Image(label="ResNet50 + U-Net [Generalized]")
    # run all models whenever a new image is uploaded; output order must
    # match the 5-tuple returned by predict_all_models
    input_image.change(
        fn=predict_all_models,
        inputs=input_image,
        outputs=[output1, output2, output3, output4, output5]
    )
# share=True exposes a public gradio.live URL in addition to localhost
demo.launch(share=True)