File size: 3,934 Bytes
c304fb7
 
 
42e3db7
c38af60
c304fb7
f1cff84
c38af60
42e3db7
c38af60
c304fb7
42e3db7
 
c304fb7
 
6a36cd0
 
 
 
 
 
 
 
 
c38af60
 
6a36cd0
f1cff84
c38af60
 
 
f1cff84
2c1d18b
 
c38af60
 
 
8239775
c38af60
 
 
c304fb7
42e3db7
c304fb7
 
42e3db7
c38af60
42e3db7
 
 
 
c38af60
 
42e3db7
c38af60
 
42e3db7
c38af60
 
42e3db7
c38af60
 
 
 
 
42e3db7
c38af60
 
 
 
 
 
 
 
 
 
 
 
 
42e3db7
c38af60
42e3db7
 
c38af60
 
 
 
 
 
 
 
 
 
 
 
6a36cd0
c38af60
 
 
 
 
 
 
 
c304fb7
 
c38af60
42e3db7
c38af60
 
 
 
6a36cd0
c38af60
 
 
 
6a36cd0
c38af60
c304fb7
c38af60
 
 
 
 
f1cff84
c38af60
c304fb7
c38af60
 
 
 
 
 
 
 
 
2c1d18b
c38af60
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation
import gradio as gr


# ---------------------------
# Load model & processor
# ---------------------------
# MobileViT-small backbone with a DeepLabV3 segmentation head, fine-tuned on
# Pascal VOC (21 classes). Weights are fetched from the Hugging Face Hub on
# first run, then cached locally.
model_checkpoint = "apple/deeplabv3-mobilevit-small"

image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
# .eval() disables dropout/batch-norm updates; inference only.
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()

# One RGB color per Pascal VOC class id (21 entries, same order as `labels`
# below). Used as a vectorized lookup table: palette[class_map] -> RGB image.
palette = np.array(
[
    [  0,   0,   0], [192,   0,   0], [  0, 192,   0], [192, 192,   0],
    [  0,   0, 192], [192,   0, 192], [  0, 192, 192], [192, 192, 192],
    [128,   0,   0], [255,   0,   0], [128, 192,   0], [255, 192,   0],
    [128,   0, 192], [255,   0, 192], [128, 192, 192], [255, 192, 192],
    [  0, 128,   0], [192, 128,   0], [  0, 255,   0], [192, 255,   0],
    [  0, 128, 192]
],
dtype=np.uint8,
)

# Pascal VOC class names; labels[i] corresponds to palette[i]. Index 0 is the
# background class (treated specially when building the overlay mask).
labels = [
    "background","aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair",
    "cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep",
    "sofa","train","tvmonitor",
]


# ---------------------------
# Prediction Function
# ---------------------------
def predict(image):
    """Segment `image` and return (colored class map, foreground overlay).

    Args:
        image: PIL image from the Gradio input, or None when nothing is uploaded.

    Returns:
        Tuple of two PIL images — the palette-colored segmentation map and the
        model input blended with a white foreground mask — or (None, None) when
        no image was provided.
    """
    if image is None:
        return None, None

    # Preprocess and run the forward pass without building autograd state.
    with torch.no_grad():
        inputs = image_processor(image, return_tensors="pt")
        outputs = model(**inputs)

    # Recover a displayable uint8 image from the preprocessed tensor:
    # CHW -> HWC, reverse the channel order (undoes the processor's flip —
    # presumably BGR back to RGB; confirm against the processor config),
    # and rescale from [0, 1] to [0, 255].
    pixels = inputs["pixel_values"].numpy().squeeze()
    pixels = np.transpose(pixels, (1, 2, 0))[:, :, ::-1]
    preview = (pixels * 255).astype(np.uint8)

    # Per-pixel class ids at the logits' (coarser) spatial resolution.
    class_map = outputs.logits.argmax(1).squeeze().cpu().numpy().astype(np.uint8)

    # Color every pixel via the class->RGB lookup table.
    seg_rgb = palette[class_map]

    # PIL sizes are (width, height); upscale with NEAREST to keep hard edges.
    target_size = (preview.shape[1], preview.shape[0])
    seg_image = Image.fromarray(seg_rgb).resize(
        target_size, resample=Image.Resampling.NEAREST
    )

    # White where any non-background class was predicted, black elsewhere.
    foreground = np.where(class_map != 0, 255, 0).astype(np.uint8)
    mask_image = (
        Image.fromarray(foreground)
        .resize(target_size, resample=Image.Resampling.NEAREST)
        .convert("RGB")
    )

    # Lighten foreground regions by blending the mask over the input preview.
    overlay = Image.blend(Image.fromarray(preview), mask_image, 0.4)

    return seg_image, overlay


# ---------------------------
# Labels HTML
# ---------------------------
# ---------------------------
# Labels HTML
# ---------------------------
# Class ids whose palette color is dark enough to need white legend text.
inverted = {0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20}

# One colored <span> per class for the legend. Cast palette entries to plain
# Python ints before formatting: under NumPy >= 2.0, tuple(palette[i]) renders
# as "(np.uint8(0), ...)", which is not a valid CSS rgb() value; with ints the
# tuple repr "(0, 0, 0)" yields valid "rgb(0, 0, 0)" syntax.
labels_html = " ".join(
    f"<span style='background-color: rgb{tuple(int(c) for c in palette[i])}; "
    f"color: {'white' if i in inverted else 'black'}; padding: 2px 4px;'>"
    f"{labels[i]}</span>"
    for i in range(len(labels))
)

# App description shown at the top of the page (trailing spaces on the first
# line are a Markdown hard line break).
description = f"""
Semantic Segmentation with MobileViT + DeepLabV3  
Model trained on Pascal VOC.<br><br>
Classes:<br>{labels_html}
"""

# Footer with source links, rendered below the demo.
article = """
<p>Sources:</p>
<ul>
<li><a href="https://arxiv.org/abs/2110.02178">MobileViT Paper</a></li>
<li><a href="https://github.com/apple/ml-cvnets">Apple ML-CVnets</a></li>
<li>Example images from <a href="https://huggingface.co/datasets/mishig/sample_images">Sample Images Dataset</a></li>
</ul>
"""


# ---------------------------
# Gradio App (Blocks)
# ---------------------------
# ---------------------------
# Gradio App (Blocks)
# ---------------------------
with gr.Blocks(title="Semantic Segmentation with MobileViT") as demo:
    # Page header and class legend.
    gr.Markdown("# Semantic Segmentation with MobileViT & DeepLabV3")
    gr.Markdown(description)

    # Input on the left, the two outputs alongside it.
    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        mask_output = gr.Image(label="Segmentation Mask")
        overlay_output = gr.Image(label="Highlighted Overlay")

    segment_button = gr.Button("Run")
    segment_button.click(
        predict,
        inputs=image_input,
        outputs=[mask_output, overlay_output],
    )

    gr.Markdown(article)

    # Clickable sample images (paths resolved relative to the app's cwd).
    sample_files = [
        "cat-3.jpg",
        "construction-site.jpg",
        "dog-cat.jpg",
        "football-match.jpg",
    ]
    gr.Examples(examples=[[path] for path in sample_files], inputs=image_input)

demo.launch()