File size: 4,307 Bytes
6cf5f32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Semantic Segmentation — Pixel-level classification with DeepLabV3
Courses: 100 ch3, 360 ch4
"""

import numpy as np
import torch
import torchvision.models.segmentation as seg_models
import torchvision.transforms as T
import gradio as gr
from PIL import Image

# CPU-only inference so the demo runs anywhere (no GPU required).
device = torch.device("cpu")

# Load DeepLabV3 with MobileNetV3 backbone (lightweight)
# Pretrained weights are downloaded on first run; eval() disables
# dropout/batch-norm updates for deterministic inference.
model = seg_models.deeplabv3_mobilenet_v3_large(
    weights=seg_models.DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
).eval().to(device)

# ImageNet mean/std normalization — the statistics the pretrained
# backbone was trained with; ToTensor also scales pixels to [0, 1].
preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# PASCAL VOC class names (21 classes)
# Index i here is the class id predicted by the model; PALETTE below
# must stay in exactly the same order.
CLASS_NAMES = [
    "background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair",
    "cow", "dining table", "dog", "horse", "motorbike",
    "person", "potted plant", "sheep", "sofa", "train",
    "tv/monitor",
]

# Color palette for each class
# Row i is the RGB color for class id i (standard VOC colormap values).
# uint8 dtype lets PALETTE[pred] produce a displayable image directly.
PALETTE = np.array([
    [0, 0, 0],        # background
    [128, 0, 0],      # aeroplane
    [0, 128, 0],      # bicycle
    [128, 128, 0],    # bird
    [0, 0, 128],      # boat
    [128, 0, 128],    # bottle
    [0, 128, 128],    # bus
    [128, 128, 128],  # car
    [64, 0, 0],       # cat
    [192, 0, 0],      # chair
    [64, 128, 0],     # cow
    [192, 128, 0],    # dining table
    [64, 0, 128],     # dog
    [192, 0, 128],    # horse
    [64, 128, 128],   # motorbike
    [192, 128, 128],  # person
    [0, 64, 0],       # potted plant
    [128, 64, 0],     # sheep
    [0, 192, 0],      # sofa
    [128, 192, 0],    # train
    [0, 64, 128],     # tv/monitor
], dtype=np.uint8)


def segment(image: Image.Image, display_mode: str):
    """Run semantic segmentation on *image* and build the display outputs.

    Args:
        image: Uploaded PIL image (any mode; converted to RGB). ``None``
            when the user has not provided an image yet.
        display_mode: "Overlay", "Segmentation Only", or "Side by Side".

    Returns:
        Tuple ``(result, seg_map, legend)`` — *result* is a uint8 numpy
        image whose content depends on *display_mode*, *seg_map* is always
        the raw color-coded class map, *legend* is markdown listing each
        detected class with its pixel coverage.
    """
    if image is None:
        return None, None, ""

    img = image.convert("RGB")
    w, h = img.size

    # Forward pass; "out" holds per-class logits of shape (1, 21, H', W').
    inp = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(inp)["out"]
    # squeeze(0), not squeeze(): a degenerate spatial dim of size 1 would
    # otherwise be collapsed too, corrupting the (H', W') class-id map.
    pred = output.argmax(1).squeeze(0).cpu().numpy()

    # Nearest-neighbor resize back to the original size so class ids are
    # never blended into nonexistent intermediate values.
    pred_resized = np.array(
        Image.fromarray(pred.astype(np.uint8)).resize((w, h), Image.NEAREST)
    )

    # Fancy indexing maps each class id to its RGB row: (H, W) -> (H, W, 3).
    seg_color = PALETTE[pred_resized]

    # 50/50 alpha blend of input image and color map.
    img_np = np.array(img)
    overlay = (img_np * 0.5 + seg_color * 0.5).astype(np.uint8)

    # Build the markdown legend, skipping class 0 (background).
    unique_classes = np.unique(pred_resized)
    legend = "**Detected classes:**\n\n"
    found_any = False
    for c in unique_classes:
        if c == 0:
            continue
        found_any = True
        color = PALETTE[c]
        pixel_pct = np.sum(pred_resized == c) / pred_resized.size * 100
        color_hex = f"#{color[0]:02x}{color[1]:02x}{color[2]:02x}"
        legend += f"- <span style='color:{color_hex};font-weight:bold;'>██</span> {CLASS_NAMES[c]} ({pixel_pct:.1f}%)\n"

    if not found_any:
        legend += "- No objects detected (background only)"

    if display_mode == "Segmentation Only":
        return seg_color, seg_color, legend
    if display_mode == "Side by Side":
        # BUG FIX: this branch previously returned the same output as
        # "Overlay". Now it actually places the original image and the
        # segmentation map next to each other.
        side_by_side = np.hstack([img_np, seg_color])
        return side_by_side, seg_color, legend
    # Default: "Overlay"
    return overlay, seg_color, legend


# UI layout: input column (image + display-mode radio + button) on the
# left, results column (result image, raw segmentation map, legend) on
# the right.
with gr.Blocks(title="Semantic Segmentation") as demo:
    gr.Markdown(
        "# Semantic Segmentation\n"
        "Upload an image to see pixel-level classification (21 PASCAL VOC classes).\n"
        "*Courses: 100 Deep Learning ch3, 360 Autonomous Driving ch4*"
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            # Radio choices must match the display_mode strings segment() checks.
            mode = gr.Radio(
                ["Overlay", "Segmentation Only", "Side by Side"],
                value="Overlay",
                label="Display Mode",
            )
            btn = gr.Button("Segment", variant="primary")

        with gr.Column(scale=2):
            with gr.Row():
                overlay_out = gr.Image(label="Result")
                seg_out = gr.Image(label="Segmentation Map")
            legend_md = gr.Markdown()

    # Wire the button to segment(): two inputs, three outputs.
    btn.click(segment, [input_image, mode], [overlay_out, seg_out, legend_md])

    # Example images — paths are relative; they must exist next to this file.
    gr.Examples(
        examples=[
            ["examples/street.jpg", "Overlay"],
            ["examples/room.jpg", "Side by Side"],
        ],
        inputs=[input_image, mode],
    )

# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()