Spaces:
Sleeping
Sleeping
File size: 3,934 Bytes
c304fb7 42e3db7 c38af60 c304fb7 f1cff84 c38af60 42e3db7 c38af60 c304fb7 42e3db7 c304fb7 6a36cd0 c38af60 6a36cd0 f1cff84 c38af60 f1cff84 2c1d18b c38af60 8239775 c38af60 c304fb7 42e3db7 c304fb7 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 42e3db7 c38af60 6a36cd0 c38af60 c304fb7 c38af60 42e3db7 c38af60 6a36cd0 c38af60 6a36cd0 c38af60 c304fb7 c38af60 f1cff84 c38af60 c304fb7 c38af60 2c1d18b c38af60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr
# ---------------------------
# Load model & processor
# ---------------------------
# MobileViT-small backbone with a DeepLabV3 head, fine-tuned on Pascal VOC
# (21 classes). Weights are fetched from the Hugging Face Hub on first run.
model_checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
# .eval() disables dropout/batch-norm updates; inference only.
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()
# Per-class RGB colors, indexed by class id (row i colors class i).
# Must stay aligned with `labels` below — 21 rows for the 21 VOC classes.
palette = np.array(
[
[ 0, 0, 0], [192, 0, 0], [ 0, 192, 0], [192, 192, 0],
[ 0, 0, 192], [192, 0, 192], [ 0, 192, 192], [192, 192, 192],
[128, 0, 0], [255, 0, 0], [128, 192, 0], [255, 192, 0],
[128, 0, 192], [255, 0, 192], [128, 192, 192], [255, 192, 192],
[ 0, 128, 0], [192, 128, 0], [ 0, 255, 0], [192, 255, 0],
[ 0, 128, 192]
],
dtype=np.uint8,
)
# Pascal VOC class names; index 0 ("background") is treated as non-object
# by the overlay mask in predict().
labels = [
"background","aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair",
"cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep",
"sofa","train","tvmonitor",
]
# ---------------------------
# Prediction Function
# ---------------------------
def predict(image):
    """Segment a PIL image with MobileViT + DeepLabV3.

    Returns a pair of PIL images: the palette-colored class map and the
    preprocessed input blended with a white foreground mask. Returns
    (None, None) when no image was supplied.
    """
    if image is None:
        return None, None

    # Forward pass; no gradients needed for inference.
    with torch.no_grad():
        model_inputs = image_processor(image, return_tensors="pt")
        prediction = model(**model_inputs)

    # Recover a displayable uint8 RGB view of the preprocessed input.
    # NOTE(review): assumes the processor rescales to [0, 1] and flips
    # channels to BGR (MobileViT default) — hence the *255 and [..., ::-1].
    pixels = model_inputs["pixel_values"].numpy().squeeze()
    pixels = (pixels.transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8)
    # PIL .resize() takes (width, height).
    target_size = (pixels.shape[1], pixels.shape[0])

    # Per-pixel class ids from the logits, then a vectorized palette lookup.
    class_map = prediction.logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)
    seg_image = Image.fromarray(palette[class_map]).resize(
        target_size,
        resample=Image.Resampling.NEAREST,
    )

    # White-on-black foreground mask (class 0 = background), upscaled with
    # nearest-neighbor so class boundaries stay crisp.
    fg_mask = (class_map != 0).astype(np.uint8) * 255
    fg_image = (
        Image.fromarray(fg_mask)
        .resize(target_size, resample=Image.Resampling.NEAREST)
        .convert("RGB")
    )
    overlay = Image.blend(Image.fromarray(pixels), fg_image, 0.4)
    return seg_image, overlay
# ---------------------------
# Labels HTML
# ---------------------------
# Palette indices whose swatch color is dark enough to need white text.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}
# Color-coded legend: one <span> per class, filled with its palette color.
labels_html = " ".join(
    # .tolist() converts the palette row to plain Python ints so the tuple
    # renders as "(r, g, b)". With tuple(palette[i]) the elements stay NumPy
    # scalars, and under NumPy >= 2.0 their repr is "np.uint8(192)", which
    # would emit invalid CSS like rgb(np.uint8(192), ...).
    f"<span style='background-color: rgb{tuple(palette[i].tolist())}; "
    f"color: {'white' if i in inverted else 'black'}; padding: 2px 4px;'>"
    f"{labels[i]}</span>"
    for i in range(len(labels))
)
description = f"""
Semantic Segmentation with MobileViT + DeepLabV3
Model trained on Pascal VOC.<br><br>
Classes:<br>{labels_html}
"""
article = """
<p>Sources:</p>
<ul>
<li><a href="https://arxiv.org/abs/2110.02178">MobileViT Paper</a></li>
<li><a href="https://github.com/apple/ml-cvnets">Apple ML-CVnets</a></li>
<li>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">Sample Images Dataset</a></li>
</ul>
"""
# ---------------------------
# Gradio App (Blocks)
# ---------------------------
# Layout: header, legend, input/output images in a row, run button,
# sources footer, and clickable example images.
with gr.Blocks(title="Semantic Segmentation with MobileViT") as demo:
    gr.Markdown("# Semantic Segmentation with MobileViT & DeepLabV3")
    gr.Markdown(description)

    with gr.Row():
        image_in = gr.Image(label="Upload Image", type="pil")
        mask_out = gr.Image(label="Segmentation Mask")
        overlay_out = gr.Image(label="Highlighted Overlay")

    segment_btn = gr.Button("Run")
    segment_btn.click(predict, inputs=image_in, outputs=[mask_out, overlay_out])

    gr.Markdown(article)
    gr.Examples(
        examples=[
            ["cat-3.jpg"],
            ["construction-site.jpg"],
            ["dog-cat.jpg"],
            ["football-match.jpg"],
        ],
        inputs=image_in,
    )
demo.launch()
|