"""Gradio demo: semantic segmentation with MobileViT + DeepLabV3 (Pascal VOC)."""

import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr

# ---------------------------
# Load model & processor
# ---------------------------
model_checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()

# Pascal VOC palette: one RGB color per class index (21 classes, uint8).
palette = np.array(
    [
        [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
        [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
        [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
        [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
        [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8,
)

# Pascal VOC class names, index-aligned with `palette`.
labels = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]


# ---------------------------
# Prediction Function
# ---------------------------
def predict(image):
    """Segment a PIL image with MobileViT + DeepLabV3.

    Parameters
    ----------
    image : PIL.Image.Image or None
        Input image from the Gradio component.

    Returns
    -------
    tuple[PIL.Image.Image, PIL.Image.Image] | tuple[None, None]
        ``(colored_mask, highlighted_overlay)``; ``(None, None)`` when no
        image was supplied.
    """
    if image is None:
        return None, None

    with torch.no_grad():
        inputs = image_processor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Re-normalize the preprocessed tensor back to a uint8 HxWx3 image.
    # NOTE: [..., ::-1] reverses the channel axis — presumably undoing the
    # processor's RGB->BGR channel flip for display; confirm against the
    # checkpoint's image-processor config.
    resized = (
        inputs["pixel_values"]
        .numpy()
        .squeeze()
        .transpose(1, 2, 0)[..., ::-1]
        * 255
    ).astype(np.uint8)

    # Per-pixel class map from the logits (argmax over the class dimension).
    classes = outputs.logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)

    # Vectorized lookup-table coloring: class index -> palette RGB.
    colored = palette[classes]

    # Resize the segmentation (logits are at a reduced resolution) to match
    # the preprocessed input; NEAREST keeps class colors un-blended.
    colored_img = Image.fromarray(colored).resize(
        (resized.shape[1], resized.shape[0]),
        resample=Image.Resampling.NEAREST,
    )

    # Binary foreground mask (every non-background pixel) for the overlay.
    mask = (classes != 0).astype(np.uint8) * 255
    mask_img = (
        Image.fromarray(mask)
        .resize(
            (resized.shape[1], resized.shape[0]),
            resample=Image.Resampling.NEAREST,
        )
        .convert("RGB")
    )

    resized_img = Image.fromarray(resized)
    highlighted = Image.blend(resized_img, mask_img, 0.4)

    return colored_img, highlighted


# ---------------------------
# Labels HTML
# ---------------------------
# Class indices whose palette color is dark enough to need white label text.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}

# NOTE(review): the original HTML tags were stripped during extraction (the
# f-string literals were empty). Reconstructed here as colored badges using
# `palette` and `inverted`; verify against the original app if available.
labels_html = " ".join(
    f'<span style="background-color: rgb({palette[i][0]}, {palette[i][1]}, {palette[i][2]}); '
    f'color: {"white" if i in inverted else "black"}; '
    f'padding: 2px 6px; margin: 2px; border-radius: 4px; display: inline-block;">'
    f"{labels[i]}</span>"
    for i in range(len(labels))
)

description = f"""
Semantic segmentation with **MobileViT + DeepLabV3**. Model trained on Pascal VOC.

Classes:

{labels_html}
"""

# NOTE(review): the original "Sources:" links were lost in extraction; the
# model-card link below is a reconstruction — restore the originals if known.
article = """
Sources:

- [apple/deeplabv3-mobilevit-small](https://huggingface.co/apple/deeplabv3-mobilevit-small)
"""

# ---------------------------
# Gradio App (Blocks)
# ---------------------------
with gr.Blocks(title="Semantic Segmentation with MobileViT") as demo:
    gr.Markdown("# Semantic Segmentation with MobileViT & DeepLabV3")
    gr.Markdown(description)

    with gr.Row():
        input_img = gr.Image(label="Upload Image", type="pil")
        output_mask = gr.Image(label="Segmentation Mask")
        output_overlay = gr.Image(label="Highlighted Overlay")

    run_btn = gr.Button("Run")
    run_btn.click(
        predict,
        inputs=input_img,
        outputs=[output_mask, output_overlay],
    )

    gr.Markdown(article)

    gr.Examples(
        examples=[
            ["cat-3.jpg"],
            ["construction-site.jpg"],
            ["dog-cat.jpg"],
            ["football-match.jpg"],
        ],
        inputs=input_img,
    )

demo.launch()