"""Gradio demo: CPU ONNX inference for DepthPolyp segmentation and pseudo-depth."""

import threading
from pathlib import Path

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image

MODEL_REPO = "ReaganWZY/DepthPolyp"
ONNX_FILENAME = "DepthPolyp_Kvasir.onnx"
IMAGE_SIZE = 224
EXAMPLES_DIR = Path("samples/kvasir/images")

# Lazily created inference session and the name of its first input.
SESSION = None
INPUT_NAME = None
# Serialises the lazy initialisation in get_session().
_SESSION_LOCK = threading.Lock()

# RGB colour stops of the purple-to-yellow ramp used for the depth map.
_DEPTH_COLOR_STOPS = np.array(
    [
        [38, 5, 84],
        [86, 33, 132],
        [141, 48, 140],
        [203, 71, 119],
        [245, 135, 48],
        [252, 231, 37],
    ],
    dtype=np.float32,
)


def get_session():
    """Return the shared ONNX Runtime session and the name of its first input.

    On first use the model file is fetched with ``hf_hub_download`` and a
    CPU-only ``InferenceSession`` is created; both results are cached in the
    module-level ``SESSION`` / ``INPUT_NAME`` globals.

    Returns:
        A ``(session, input_name)`` tuple.
    """
    global SESSION, INPUT_NAME
    # Several event listeners share this function and may call it from
    # different worker threads. The lock keeps the model from being loaded
    # twice and guarantees nobody sees SESSION set while INPUT_NAME is None.
    with _SESSION_LOCK:
        if SESSION is None:
            model_path = hf_hub_download(repo_id=MODEL_REPO, filename=ONNX_FILENAME)
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )
            INPUT_NAME = session.get_inputs()[0].name
            # Publish the session last so both globals are consistent.
            SESSION = session
        return SESSION, INPUT_NAME


def preprocess(image: Image.Image):
    """Turn an input image into the tensor fed to the model.

    Args:
        image: Image in any PIL mode; it is converted to RGB.

    Returns:
        A ``(rgb_image, original_size, tensor)`` tuple: the RGB image at its
        original resolution, its ``(width, height)``, and a float32 array of
        shape ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` with values scaled to [0, 1].
    """
    rgb = image.convert("RGB")
    original_size = rgb.size
    resized = rgb.resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
    pixels = np.asarray(resized).astype(np.float32) / 255.0
    # HWC -> CHW, then add the leading batch axis.
    tensor = np.transpose(pixels, (2, 0, 1))[None, ...]
    return rgb, original_size, tensor


def to_grayscale(probability: np.ndarray, size):
    """Render a 2-D map as an 8-bit grayscale image resized to ``size``.

    Args:
        probability: 2-D array; values are clipped to [0, 1] before scaling.
        size: Target ``(width, height)``.

    Returns:
        A mode ``"L"`` PIL image of the requested size.
    """
    clipped = np.clip(probability, 0.0, 1.0)
    gray = Image.fromarray((clipped * 255).astype(np.uint8), mode="L")
    return gray.resize(size, Image.BILINEAR)


def colorize_purple_yellow(probability: np.ndarray, size):
    """Map a 2-D map onto a purple-to-yellow colour ramp.

    Values are clipped to [0, 1] and linearly interpolated between the
    neighbouring colour stops.

    Args:
        probability: 2-D array of values.
        size: Target ``(width, height)``.

    Returns:
        A mode ``"RGB"`` PIL image of the requested size.
    """
    clipped = np.clip(probability, 0.0, 1.0)
    last_index = len(_DEPTH_COLOR_STOPS) - 1
    position = clipped * last_index
    lower = np.floor(position).astype(np.int32)
    # At position == last_index the upper stop is clamped; the weight is 0 there.
    upper = np.clip(lower + 1, 0, last_index)
    weight = (position - lower)[..., None]
    blended = (
        _DEPTH_COLOR_STOPS[lower] * (1.0 - weight)
        + _DEPTH_COLOR_STOPS[upper] * weight
    )
    colored = Image.fromarray(blended.astype(np.uint8), mode="RGB")
    return colored.resize(size, Image.BILINEAR)


def make_overlay(image: Image.Image, mask_probability: Image.Image):
    """Blend a yellow highlight over ``image``, weighted by the mask.

    Args:
        image: Base image.
        mask_probability: Single-channel 8-bit image with the same size as
            ``image``; brighter pixels get a more opaque highlight
            (alpha ranges from 0 to 155).

    Returns:
        The composited image in RGB mode.
    """
    base = image.convert("RGBA")
    strength = np.asarray(mask_probability).astype(np.float32) / 255.0
    height, width = strength.shape[:2]
    highlight = np.zeros((height, width, 4), dtype=np.uint8)
    highlight[..., 0] = 252
    highlight[..., 1] = 231
    highlight[..., 2] = 37
    highlight[..., 3] = (strength * 155).astype(np.uint8)
    layer = Image.fromarray(highlight, mode="RGBA")
    return Image.alpha_composite(base, layer).convert("RGB")


def run_depthpolyp(image: Image.Image, threshold: float):
    """Run the model on ``image`` and build the three output images.

    Args:
        image: Uploaded image, or ``None`` when nothing is loaded.
        threshold: Cut-off in [0, 1] applied to the segmentation map to
            produce the binary mask.

    Returns:
        An ``(overlay, mask, depth_image)`` tuple of PIL images at the
        original resolution.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    session, input_name = get_session()
    original, original_size, tensor = preprocess(image)

    # NOTE(review): assumes the model yields exactly two outputs
    # (segmentation, depth), each shaped (1, 1, H, W) with values in [0, 1]
    # -- confirm against the exported ONNX graph.
    segmentation, depth = session.run(None, {input_name: tensor})
    seg_prob = segmentation[0, 0]
    depth_prob = depth[0, 0]

    seg_image = to_grayscale(seg_prob, original_size)
    # Computed once instead of on every call of the point() callback.
    cutoff = int(threshold * 255)
    mask = seg_image.point(lambda value: 255 if value >= cutoff else 0)
    depth_image = colorize_purple_yellow(depth_prob, original_size)
    overlay = make_overlay(original, seg_image)
    return overlay, mask, depth_image


def _run_on_change(image: Image.Image, threshold: float):
    """Change-event handler: like ``run_depthpolyp`` but quiet without an image.

    Clearing the image or moving the slider before an upload should not show
    an error, so the outputs are simply cleared in that case.
    """
    if image is None:
        return None, None, None
    return run_depthpolyp(image, threshold)


def example_paths():
    """List example rows ``[image_path, 0.3]`` for images in ``EXAMPLES_DIR``.

    Returns:
        A sorted list of ``[path, threshold]`` rows for ``.jpg``, ``.jpeg``
        and ``.png`` files, or an empty list when the directory is missing.
    """
    if not EXAMPLES_DIR.exists():
        return []
    allowed_suffixes = {".jpg", ".jpeg", ".png"}
    return [
        [str(path), 0.3]
        for path in sorted(EXAMPLES_DIR.glob("*"))
        if path.suffix.lower() in allowed_suffixes
    ]


with gr.Blocks(title="DepthPolyp Demo") as demo:
    gr.Markdown(
        """
# DepthPolyp

CPU ONNX demo for pseudo-depth guided colonoscopic polyp segmentation.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Colonoscopy image")
            threshold = gr.Slider(
                0.05, 0.95, value=0.3, step=0.05, label="Mask threshold"
            )
            run_button = gr.Button("Run inference", variant="primary")
        with gr.Column(scale=2):
            with gr.Row():
                overlay = gr.Image(type="pil", label="Segmentation overlay")
                mask = gr.Image(type="pil", label="Binary mask")
                depth = gr.Image(type="pil", label="Pseudo-depth")

    # The explicit button keeps the error message when no image is loaded.
    run_button.click(
        fn=run_depthpolyp,
        inputs=[input_image, threshold],
        outputs=[overlay, mask, depth],
    )
    # Automatic re-runs use the quiet handler so clearing the image or
    # touching the slider before an upload does not raise an error.
    input_image.change(
        fn=_run_on_change,
        inputs=[input_image, threshold],
        outputs=[overlay, mask, depth],
    )
    threshold.change(
        fn=_run_on_change,
        inputs=[input_image, threshold],
        outputs=[overlay, mask, depth],
    )

    gr.Examples(
        examples=example_paths(),
        inputs=[input_image, threshold],
        outputs=[overlay, mask, depth],
        fn=run_depthpolyp,
        cache_examples=False,
    )

    gr.Markdown(
        "Model: [ReaganWZY/DepthPolyp](https://huggingface.co/ReaganWZY/DepthPolyp). "
        "For research use only; not for clinical diagnosis."
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)