File size: 4,120 Bytes
66de2dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Image Classification — Compare ResNet-50 / ViT-base / MobileNetV3
Course: 100 Deep Learning ch2
"""

import json
import urllib.request

import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import timm
import gradio as gr
from PIL import Image

device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
model_registry = {
    "ResNet-50": models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
    "MobileNetV3-Small": models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
    "ViT-Base (timm)": timm.create_model("vit_base_patch16_224", pretrained=True),
}
for m in model_registry.values():
    m.eval().to(device)

# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageNet labels
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    with urllib.request.urlopen(LABELS_URL) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    LABELS = [str(i) for i in range(1000)]


# ---------------------------------------------------------------------------
# Classify
# ---------------------------------------------------------------------------
def classify(image: Image.Image, model_name: str):
    if image is None:
        return {}
    img = image.convert("RGB")
    tensor = preprocess(img).unsqueeze(0).to(device)

    model = model_registry[model_name]
    with torch.no_grad():
        logits = model(tensor)

    probs = F.softmax(logits, dim=1)[0]
    top5 = torch.topk(probs, 5)
    return {LABELS[idx]: float(prob) for prob, idx in zip(top5.values, top5.indices)}


def compare_all(image: Image.Image):
    """Run all 3 models and return results."""
    if image is None:
        return {}, {}, {}
    r1 = classify(image, "ResNet-50")
    r2 = classify(image, "MobileNetV3-Small")
    r3 = classify(image, "ViT-Base (timm)")
    return r1, r2, r3


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="Image Classification") as demo:
    gr.Markdown(
        "# Image Classification\n"
        "Upload an image to compare predictions from different architectures.\n"
        "*Course: 100 Deep Learning ch2 — CNN*"
    )

    with gr.Tab("Single Model"):
        with gr.Row():
            with gr.Column():
                img_single = gr.Image(type="pil", label="Upload Image")
                model_choice = gr.Dropdown(
                    list(model_registry.keys()), value="ResNet-50", label="Model"
                )
                btn_single = gr.Button("Classify", variant="primary")
            with gr.Column():
                out_single = gr.Label(num_top_classes=5, label="Top-5 Predictions")

        btn_single.click(classify, [img_single, model_choice], out_single)

    with gr.Tab("Compare All Models"):
        with gr.Row():
            img_compare = gr.Image(type="pil", label="Upload Image")
            btn_compare = gr.Button("Compare All", variant="primary")
        with gr.Row():
            out_resnet = gr.Label(num_top_classes=5, label="ResNet-50")
            out_mobile = gr.Label(num_top_classes=5, label="MobileNetV3-Small")
            out_vit = gr.Label(num_top_classes=5, label="ViT-Base")

        btn_compare.click(compare_all, [img_compare], [out_resnet, out_mobile, out_vit])

    gr.Examples(
        examples=[
            ["examples/cat.jpg"],
            ["examples/dog.jpg"],
            ["examples/car.jpg"],
        ],
        inputs=[img_single],
    )

if __name__ == "__main__":
    demo.launch()