# Hugging Face Space app.py by blanchon (commit 42e3db7, verified)
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr
# ---------------------------
# Load model & processor
# ---------------------------
# DeepLabV3 segmentation head on a MobileViT-small backbone, pretrained on
# Pascal VOC (21 classes). Downloaded from the Hub at import time.
model_checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
# .eval() switches off dropout / batch-norm updates for inference.
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
# RGB lookup table: row i is the display color for class id i
# (rows correspond one-to-one with `labels` below).
palette = np.array(
[
[ 0, 0, 0], [192, 0, 0], [ 0, 192, 0], [192, 192, 0],
[ 0, 0, 192], [192, 0, 192], [ 0, 192, 192], [192, 192, 192],
[128, 0, 0], [255, 0, 0], [128, 192, 0], [255, 192, 0],
[128, 0, 192], [255, 0, 192], [128, 192, 192], [255, 192, 192],
[ 0, 128, 0], [192, 128, 0], [ 0, 255, 0], [192, 255, 0],
[ 0, 128, 192]
],
dtype=np.uint8,
)
# Pascal VOC class names; index 0 ("background") is treated as "no object"
# when building the overlay mask in predict().
labels = [
"background","aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair",
"cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep",
"sofa","train","tvmonitor",
]
# ---------------------------
# Prediction Function
# ---------------------------
def predict(image):
    """Segment a PIL image with the MobileViT/DeepLabV3 model.

    Returns a ``(segmentation, overlay)`` pair of PIL images, or
    ``(None, None)`` when no image is supplied.
    """
    if image is None:
        return None, None
    with torch.no_grad():
        inputs = image_processor(image, return_tensors="pt")
        outputs = model(**inputs)
    # Undo the processor's scaling and channel flip so the preprocessed
    # tensor can be displayed again as a uint8 RGB image.
    pixels = inputs["pixel_values"].numpy().squeeze()
    pixels = pixels.transpose(1, 2, 0)[..., ::-1]
    resized = (pixels * 255).astype(np.uint8)
    # Per-pixel class ids at the (smaller) logits resolution.
    classes = outputs.logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    # PIL resize takes (width, height).
    target_size = (resized.shape[1], resized.shape[0])
    # Vectorized palette lookup, then nearest-neighbour upscale so class
    # boundaries stay crisp.
    colored_img = Image.fromarray(palette[classes]).resize(
        target_size,
        resample=Image.Resampling.NEAREST,
    )
    # White wherever any non-background class was predicted.
    mask = np.where(classes != 0, np.uint8(255), np.uint8(0))
    mask_img = (
        Image.fromarray(mask)
        .resize(target_size, resample=Image.Resampling.NEAREST)
        .convert("RGB")
    )
    # Blend the white mask over the input to highlight detected regions.
    highlighted = Image.blend(Image.fromarray(resized), mask_img, 0.4)
    return colored_img, highlighted
# ---------------------------
# Labels HTML
# ---------------------------
# Class ids whose swatch color is dark enough to need white text.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}
# One colored <span> per class, shown in the app description.
_spans = []
for idx, class_name in enumerate(labels):
    text_color = "white" if idx in inverted else "black"
    _spans.append(
        f"<span style='background-color: rgb{tuple(palette[idx])}; "
        f"color: {text_color}; padding: 2px 4px;'>"
        f"{class_name}</span>"
    )
labels_html = " ".join(_spans)
# HTML/Markdown blurb rendered under the app title (includes the colored
# class legend built above).
description = f"""
Semantic Segmentation with MobileViT + DeepLabV3
Model trained on Pascal VOC.<br><br>
Classes:<br>{labels_html}
"""
# Footer with source links, rendered at the bottom of the page.
article = """
<p>Sources:</p>
<ul>
<li><a href="https://arxiv.org/abs/2110.02178">MobileViT Paper</a></li>
<li><a href="https://github.com/apple/ml-cvnets">Apple ML-CVnets</a></li>
<li>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">Sample Images Dataset</a></li>
</ul>
"""
# ---------------------------
# Gradio App (Blocks)
# ---------------------------
# Layout: title + description, one row with the input image and the two
# result images, a Run button wired to predict(), then footer + examples.
with gr.Blocks(title="Semantic Segmentation with MobileViT") as demo:
    gr.Markdown("# Semantic Segmentation with MobileViT & DeepLabV3")
    gr.Markdown(description)
    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        mask_output = gr.Image(label="Segmentation Mask")
        overlay_output = gr.Image(label="Highlighted Overlay")
    run_button = gr.Button("Run")
    run_button.click(
        fn=predict,
        inputs=image_input,
        outputs=[mask_output, overlay_output],
    )
    gr.Markdown(article)
    # Clicking an example only fills the input; the user still presses Run.
    gr.Examples(
        examples=[
            ["cat-3.jpg"],
            ["construction-site.jpg"],
            ["dog-cat.jpg"],
            ["football-match.jpg"],
        ],
        inputs=image_input,
    )
demo.launch()