import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr
# ---------------------------
# Load model & processor
# ---------------------------
# Checkpoint identifier handed to both `from_pretrained` calls below, so the
# preprocessing and the weights always come from the same source.
model_checkpoint = "apple/deeplabv3-mobilevit-small"

image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint)
model.eval()  # evaluation mode: this script only runs inference
# Color lookup table: row i is the color painted for class id i when `predict`
# indexes this array with the per-pixel class map. The trailing comments give
# the entry of `labels` found at the same index.
palette = np.array(
    [
        (0, 0, 0),        # 0  background
        (192, 0, 0),      # 1  aeroplane
        (0, 192, 0),      # 2  bicycle
        (192, 192, 0),    # 3  bird
        (0, 0, 192),      # 4  boat
        (192, 0, 192),    # 5  bottle
        (0, 192, 192),    # 6  bus
        (192, 192, 192),  # 7  car
        (128, 0, 0),      # 8  cat
        (255, 0, 0),      # 9  chair
        (128, 192, 0),    # 10 cow
        (255, 192, 0),    # 11 diningtable
        (128, 0, 192),    # 12 dog
        (255, 0, 192),    # 13 horse
        (128, 192, 192),  # 14 motorbike
        (255, 192, 192),  # 15 person
        (0, 128, 0),      # 16 pottedplant
        (192, 128, 0),    # 17 sheep
        (0, 255, 0),      # 18 sofa
        (192, 255, 0),    # 19 train
        (0, 128, 192),    # 20 tvmonitor
    ],
    dtype=np.uint8,
)
# Human-readable class names; the list position is the class id.
labels = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]
# ---------------------------
# Prediction Function
# ---------------------------
def predict(image):
    """Segment *image* and return a class-color map plus a highlighted overlay.

    Args:
        image: Input picture in any form accepted by ``image_processor``,
            or ``None`` when nothing was supplied.

    Returns:
        A ``(colored_img, highlighted)`` pair of PIL images, both at the
        spatial size of the preprocessed model input. ``colored_img`` paints
        every pixel with the ``palette`` row of its predicted class.
        ``highlighted`` is the preprocessed input blended with a white mask
        over every pixel whose predicted class is not 0. Returns
        ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        return None, None

    with torch.no_grad():
        inputs = image_processor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Rebuild a displayable uint8 image from the tensor fed to the model:
    # take the single batch item, move channels last, reverse the channel
    # order and scale by 255.
    # NOTE(review): assumes the processor emits values in [0, 1] with the
    # channel order flipped -- confirm if the checkpoint/processor changes.
    chw = inputs["pixel_values"][0].cpu().numpy()
    hwc = chw.transpose(1, 2, 0)[..., ::-1]
    # Round before casting: a plain astype() truncates, so float round-off
    # that leaves a value just under an integer would darken it by one
    # level. Clipping keeps any overshoot from wrapping around in uint8.
    resized = np.ascontiguousarray(
        np.rint(hwc * 255).clip(0, 255).astype(np.uint8)
    )

    # Per-pixel class ids for the single batch item.
    classes = outputs.logits.argmax(1)[0].cpu().numpy().astype(np.uint8)

    # PIL sizes are (width, height); both outputs are scaled up to the
    # preprocessed input with nearest-neighbor so class edges stay crisp.
    target_size = (resized.shape[1], resized.shape[0])

    # Vectorized lookup-table coloring of the class map.
    colored = palette[classes]
    colored_img = Image.fromarray(colored).resize(
        target_size, resample=Image.Resampling.NEAREST
    )

    # White wherever the predicted class is not 0, black elsewhere.
    mask = (classes != 0).astype(np.uint8) * 255
    mask_img = (
        Image.fromarray(mask)
        .resize(target_size, resample=Image.Resampling.NEAREST)
        .convert("RGB")
    )

    resized_img = Image.fromarray(resized)
    highlighted = Image.blend(resized_img, mask_img, 0.4)
    return colored_img, highlighted
# ---------------------------
# Labels HTML
# ---------------------------
# Class ids singled out as "inverted". Nothing in the visible code reads this
# set -- presumably it selects a contrasting text color per label; TODO confirm.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}

# Every class name, separated by single spaces, for the description below.
labels_html = " ".join(labels)

description = f"""
Semantic Segmentation with MobileViT + DeepLabV3
Model trained on Pascal VOC.
Classes:
{labels_html}
"""
article = """
Sources: