# abdulbasitdev's picture
# Deploy MALUNet CVC-ClinicDB Gradio app
# 9ff8625 verified
"""Gradio Space for MALUNet polyp segmentation (CVC-ClinicDB).
Set the MODEL_REPO environment variable on the Space to your HF model repo
(e.g. "your-username/malunet-cvc"). If unset, falls back to a local
`best.pth` next to this file.
"""
import os
import gradio as gr
import numpy as np
from PIL import Image
from infer import load_model, overlay, predict_mask
# Source of the weights handed to load_model(): the value of the MODEL_REPO
# environment variable, or the literal "best.pth" when the variable is absent.
# A variable that is defined but blank (or whitespace-only) is treated the same
# as an unset one, so an empty setting never reaches load_model() as "".
MODEL_REPO = os.environ.get("MODEL_REPO", "").strip() or "best.pth"

# Load the model once at import time; segment() reuses this module-level
# instance for every call.
print(f"loading weights from: {MODEL_REPO}")
_model = load_model(MODEL_REPO)
print("model ready")
def segment(image: Image.Image, threshold: float):
    """Segment one image with the module-level model.

    Args:
        image: Picture to segment, or ``None`` when no image was supplied.
        threshold: Cut-off forwarded to ``predict_mask`` after being
            converted to ``float``.

    Returns:
        A two-item tuple: the predicted mask wrapped with
        ``Image.fromarray`` and the result of ``overlay(image, mask)``.
        ``(None, None)`` is returned when ``image`` is ``None``.
    """
    if image is not None:
        cutoff = float(threshold)
        predicted = predict_mask(_model, image, threshold=cutoff)
        mask_image = Image.fromarray(predicted)
        blended = overlay(image, predicted)
        return mask_image, blended

    # Nothing to process: clear both outputs.
    return None, None
# Gradio UI: an input column (image, threshold slider, button) next to an
# output column (predicted mask, overlay). The button runs segment().
# NOTE(review): the nesting below is reconstructed from statement order because
# the indentation was lost -- confirm it matches the intended layout.
# The separator in the title/heading is written as the escape \u00b7 (middle
# dot) so it cannot be corrupted again by a wrong file encoding.
with gr.Blocks(title="MALUNet \u00b7 CVC-ClinicDB Polyp Segmentation") as demo:
    gr.Markdown(
        "# MALUNet \u00b7 CVC-ClinicDB polyp segmentation\n"
        "Lightweight U-shape network (~0.18 M params) trained on CVC-ClinicDB. "
        "Upload a colonoscopy frame; the model predicts a binary polyp mask."
    )
    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input image")
            thr = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Threshold")
            btn = gr.Button("Segment", variant="primary")
        with gr.Column():
            mask_out = gr.Image(type="pil", label="Predicted mask")
            ovl_out = gr.Image(type="pil", label="Overlay")
    btn.click(segment, inputs=[inp, thr], outputs=[mask_out, ovl_out])

    # Optional example gallery: every image file found in an "examples"
    # directory beside this script, each paired with the default threshold.
    examples_dir = os.path.join(os.path.dirname(__file__), "examples")
    if os.path.isdir(examples_dir):
        ex = sorted(
            os.path.join(examples_dir, f)
            for f in os.listdir(examples_dir)
            # ".tiff" is accepted alongside ".tif"; both name the same format.
            if f.lower().endswith((".png", ".jpg", ".jpeg", ".tif", ".tiff"))
        )
        if ex:
            gr.Examples(examples=[[p, 0.5] for p in ex], inputs=[inp, thr])

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()