"""
Image Classification — Compare ResNet-50 / ViT-base / MobileNetV3
Course: 100 Deep Learning ch2
"""

import json
import logging
import os
import urllib.request

import torch
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
import timm
import gradio as gr
from PIL import Image

logger = logging.getLogger(__name__)

device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
model_registry = {
    "ResNet-50": models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
    "MobileNetV3-Small": models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    ),
    "ViT-Base (timm)": timm.create_model("vit_base_patch16_224", pretrained=True),
}

for m in model_registry.values():
    m.eval().to(device)

# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Default pipeline, used for the torchvision checkpoints.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Per-model preprocessing. The timm checkpoint carries its own data config
# (input size, interpolation, mean/std), which is not guaranteed to match the
# ImageNet statistics above, so its eval transform is built from that config
# instead of reusing the default pipeline.
preprocess_registry = {name: preprocess for name in model_registry}
preprocess_registry["ViT-Base (timm)"] = timm.data.create_transform(
    **timm.data.resolve_data_config({}, model=model_registry["ViT-Base (timm)"])
)

# ---------------------------------------------------------------------------
# ImageNet labels
# ---------------------------------------------------------------------------
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
NUM_CLASSES = 1000
LABELS_TIMEOUT_S = 10


def _load_labels():
    """Download the human-readable ImageNet labels.

    Returns:
        A list of ``NUM_CLASSES`` label strings. If the download fails, times
        out, or yields something other than a list of that length, the class
        indices rendered as strings (``"0"`` .. ``"999"``) are returned so the
        app can still start.
    """
    fallback = [str(i) for i in range(NUM_CLASSES)]
    try:
        # The timeout keeps an unreachable host from blocking startup forever.
        with urllib.request.urlopen(LABELS_URL, timeout=LABELS_TIMEOUT_S) as resp:
            labels = json.loads(resp.read().decode())
    except Exception:
        # Best effort by design: label names are cosmetic, so log and degrade.
        logger.warning(
            "Could not load labels from %s; using class indices instead.",
            LABELS_URL,
            exc_info=True,
        )
        return fallback

    if not isinstance(labels, list) or len(labels) != NUM_CLASSES:
        logger.warning(
            "Unexpected label payload from %s; using class indices instead.",
            LABELS_URL,
        )
        return fallback
    return [str(label) for label in labels]


LABELS = _load_labels()


# ---------------------------------------------------------------------------
# Classify
# ---------------------------------------------------------------------------
def classify(image: Image.Image, model_name: str):
    """Classify *image* with the model registered under *model_name*.

    Args:
        image: Input picture, or ``None`` when nothing has been uploaded.
        model_name: A key of ``model_registry``.

    Returns:
        A dict mapping the five highest-scoring labels to their softmax
        probabilities, or an empty dict when *image* is ``None``.

    Raises:
        KeyError: If *model_name* is not a key of ``model_registry``.
    """
    if image is None:
        return {}

    model = model_registry[model_name]
    transform = preprocess_registry.get(model_name, preprocess)

    # Add a leading batch dimension of size one before calling the model.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)

    probs = F.softmax(logits, dim=1)[0]
    top5 = torch.topk(probs, 5)
    # Convert to plain Python numbers before indexing the label list.
    return {
        LABELS[index]: prob
        for prob, index in zip(top5.values.tolist(), top5.indices.tolist())
    }


def compare_all(image: Image.Image):
    """Run all 3 models and return results.

    Returns:
        A 3-tuple of ``classify`` results in the order ResNet-50,
        MobileNetV3-Small, ViT-Base; three empty dicts when *image* is
        ``None``.
    """
    if image is None:
        return {}, {}, {}
    r1 = classify(image, "ResNet-50")
    r2 = classify(image, "MobileNetV3-Small")
    r3 = classify(image, "ViT-Base (timm)")
    return r1, r2, r3


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
EXAMPLE_IMAGES = [
    "examples/cat.jpg",
    "examples/dog.jpg",
    "examples/car.jpg",
]

with gr.Blocks(title="Image Classification") as demo:
    gr.Markdown(
        "# Image Classification\n"
        "Upload an image to compare predictions from different architectures.\n"
        "*Course: 100 Deep Learning ch2 — CNN*"
    )

    with gr.Tab("Single Model"):
        with gr.Row():
            with gr.Column():
                img_single = gr.Image(type="pil", label="Upload Image")
                model_choice = gr.Dropdown(
                    list(model_registry.keys()), value="ResNet-50", label="Model"
                )
                btn_single = gr.Button("Classify", variant="primary")
            with gr.Column():
                out_single = gr.Label(num_top_classes=5, label="Top-5 Predictions")
        btn_single.click(classify, [img_single, model_choice], out_single)

    with gr.Tab("Compare All Models"):
        with gr.Row():
            img_compare = gr.Image(type="pil", label="Upload Image")
            btn_compare = gr.Button("Compare All", variant="primary")
        with gr.Row():
            out_resnet = gr.Label(num_top_classes=5, label="ResNet-50")
            out_mobile = gr.Label(num_top_classes=5, label="MobileNetV3-Small")
            out_vit = gr.Label(num_top_classes=5, label="ViT-Base")
        btn_compare.click(compare_all, [img_compare], [out_resnet, out_mobile, out_vit])

    # The example paths are relative to the working directory; register only
    # the files that are actually present so a missing asset cannot break the
    # app at startup.
    available_examples = [[path] for path in EXAMPLE_IMAGES if os.path.isfile(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=[img_single],
        )

if __name__ == "__main__":
    demo.launch()