| import gradio as gr |
| import tensorflow as tf |
| import numpy as np |
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import matplotlib.cm as cm |
|
|
| |
# Fetch each segmentation checkpoint from the Hugging Face Hub. The values
# returned by hf_hub_download are handed to load_model below as file paths.
unet_path = hf_hub_download(
    repo_id="treephones/tool-seg-unet",
    filename="1_simple_unet.keras",
)
resnet_path = hf_hub_download(
    repo_id="treephones/tool-seg-resnet50-unet",
    filename="2_resnet50_unet.keras",
)
resnet_path_generalized = hf_hub_download(
    repo_id="treephones/tool-seg-resnet50-generalized",
    filename="2_resnet50_unet_generalized.keras",
)
fcn_path = hf_hub_download(
    repo_id="treephones/tool-seg-fcn-8s",
    filename="3_fcn_8s.keras",
)

# Load the four models once at import time. compile=False loads them without
# compiling; this file only ever calls predict() on them.
unet = tf.keras.models.load_model(unet_path, compile=False)
fcn_8s = tf.keras.models.load_model(fcn_path, compile=False)
rn50_unet = tf.keras.models.load_model(resnet_path, compile=False)
rn50_unet_generalized = tf.keras.models.load_model(
    resnet_path_generalized,
    compile=False,
)
|
|
| |
def preprocess_image(path):
    """Read an image file and turn it into a normalized single-image batch.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        A float32 tensor of shape (1, 256, 256, 3) whose pixel values have
        been divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    # Uploads are not guaranteed to be PNG, so let TensorFlow detect the
    # format. expand_animations=False keeps the result 3-D (first frame only
    # for GIFs) so that tf.image.resize below always gets an H x W x C tensor.
    pixels = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    pixels = tf.cast(pixels, tf.float32) / 255.0
    pixels = tf.image.resize(
        pixels,
        (256, 256),
        method=tf.image.ResizeMethod.BILINEAR,
    )
    # Add the leading batch axis expected by model.predict.
    return tf.expand_dims(pixels, axis=0)
|
|
| |
def apply_colormap(mask):
    """Render a mask as a uint8 RGB image using the 'viridis' colormap.

    Args:
        mask: Array of mask values; assumed to be 2-D (H, W) with values in
            [0, 1], as produced by postprocess_output -- TODO confirm for
            other callers.

    Returns:
        A uint8 array holding the first three (RGB) channels of the colormap
        output, scaled by 255.
    """
    try:
        # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed
        # in 3.9; the colormaps registry is the supported lookup.
        from matplotlib import colormaps
        colormap = colormaps['viridis']
    except ImportError:
        # Older Matplotlib releases without the registry still provide
        # cm.get_cmap.
        colormap = cm.get_cmap('viridis')

    # Drop the trailing alpha channel, keeping only RGB.
    colored = colormap(mask)[:, :, :3]
    return (colored * 255).astype(np.uint8)
|
|
| |
def postprocess_output(output):
    """Convert a raw model prediction into a colour-mapped image.

    Args:
        output: Prediction array returned by a model's predict() call;
            assumed to reduce to a 2-D (H, W) mask once size-1 axes are
            removed -- TODO confirm for multi-channel outputs.

    Returns:
        The uint8 RGB array produced by apply_colormap for the clipped mask.
    """
    # np.squeeze removes every size-1 axis (batch and single channel alike),
    # so no separate "drop the trailing channel" step is needed afterwards.
    mask = np.squeeze(output)
    # Keep values inside [0, 1] before they are mapped to colours.
    mask = np.clip(mask, 0, 1)
    return apply_colormap(mask)
|
|
| |
def predict_all_models(img):
    """Run every loaded segmentation model on one image.

    Args:
        img: Path of the input image file, or None when no image is set.

    Returns:
        A 5-tuple ``(original, unet, fcn_8s, rn50_unet, rn50_unet_generalized)``
        where ``original`` is the resized input as a uint8 array and the other
        entries are the colour-mapped predictions of the respective models.
        When ``img`` is None, a tuple of five ``None`` values is returned.
    """
    if img is None:
        # This function is wired to the image component's change event, which
        # can fire with no image (e.g. after the upload is cleared). Return
        # empty outputs instead of failing while trying to read a None path.
        return None, None, None, None, None

    batch = preprocess_image(img)

    out1 = unet.predict(batch)
    out2 = fcn_8s.predict(batch)
    out3 = rn50_unet.predict(batch)
    out4 = rn50_unet_generalized.predict(batch)

    # Debug aid: log the raw prediction shapes of all four models.
    print("Output shapes:", out1.shape, out2.shape, out3.shape, out4.shape)

    out1_img = postprocess_output(out1)
    out2_img = postprocess_output(out2)
    out3_img = postprocess_output(out3)
    out4_img = postprocess_output(out4)

    # Undo the 1/255 scaling so the resized input can be shown as an image.
    original_img = (batch[0].numpy() * 255).astype(np.uint8)

    return original_img, out1_img, out2_img, out3_img, out4_img
|
|
| |
# Page layout: heading and description, sample image links, one upload
# widget, and five image panels that are refreshed whenever the upload changes.
with gr.Blocks() as demo:
    gr.Markdown("# Generalized Surgical Instrument Segmentation [All Models]")
    gr.Markdown("Upload a single image to see how each segmentation model performs. The first row of output shows the non-generalized basic models. The second row of output shows the better model (U-Net with pretrained ResNet50 encoder) non-generalized and generalized. Generalized = trained on both Laparoscopic Cholecystectomy data and Cataract operation data, Non-generalized = trained only on Cholecystectomy data. This is the free tier so compute power is limited. The only time you may see an error is if multiple people are using it at once. Just refresh and try again if that does happen. By Moez B. and Steven K.")

    # Sample links, rendered one Markdown element per line.
    gr.Markdown("Test images to copy and paste in if you don't want to find your own:")
    gr.Markdown("## Laparoscopic Cholecystectomy:")
    for sample_link in (
        "1. https://medtube.net/images/min/bcdfd6b4e85759ae7dacb7778b679aa6/459/345/1",
        "2. https://www.sages.org/wp-content/uploads/2019/09/fecXdNs6rp0.jpg",
        "3. https://medtube.net/images/min/6041ca08cd4d11099559472fa4cd992f/459/345/1",
    ):
        gr.Markdown(sample_link)

    gr.Markdown("## Cataract Surgery:")
    for sample_link in (
        "1. https://core4-cms.imgix.net/Screen%20Shot%202021-12-13%20at%203.41.12%20PM_1639428086.png",
        "2. https://core4-cms.imgix.net/Screen%20Shot%202021-11-30%20at%2010.31.36%20AM_1638286314.png",
        "3. https://i.ytimg.com/vi/iFQ9-i41Z0I/sddefault.jpg",
    ):
        gr.Markdown(sample_link)

    # type="filepath" makes the callback receive the upload as a file path.
    with gr.Row():
        input_image = gr.Image(label="Upload an Image", type="filepath")

    # First output row: resized input next to the two basic models.
    with gr.Row():
        output1 = gr.Image(label="Original Image (Resized)")
        output2 = gr.Image(label="Simple U-Net")
        output3 = gr.Image(label="FCN-8s")

    # Second output row: the two ResNet50 + U-Net variants.
    with gr.Row():
        output4 = gr.Image(label="ResNet50 + U-Net [Non-generalized]")
        output5 = gr.Image(label="ResNet50 + U-Net [Generalized]")

    # Re-run all models each time the uploaded image changes; the five
    # return values map onto the panels in this order.
    input_image.change(
        fn=predict_all_models,
        inputs=input_image,
        outputs=[
            output1,
            output2,
            output3,
            output4,
            output5,
        ],
    )

demo.launch(share=True)
|
|