jeffliulab commited on
Commit
66de2dc
·
verified ·
1 Parent(s): 8eddc5e

Initial deploy

Browse files
Files changed (6) hide show
  1. README.md +10 -5
  2. app.py +121 -0
  3. examples/car.jpg +0 -0
  4. examples/cat.jpg +0 -0
  5. examples/dog.jpg +0 -0
  6. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,17 @@
1
  ---
2
  title: Image Classification
3
- emoji: 🐨
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.11.0
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Image Classification
3
+ emoji: "\U0001F3AF"
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: "5.29.0"
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Image Classification ResNet / ViT / MobileNet
14
+
15
+ Upload an image and compare predictions across different CNN and Transformer architectures.
16
+
17
+ **Course**: 100 Deep Learning ch2 — Convolutional Neural Networks
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Image Classification — Compare ResNet-50 / ViT-base / MobileNetV3
3
+ Course: 100 Deep Learning ch2
4
+ """
5
+
6
+ import json
7
+ import urllib.request
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torchvision.models as models
12
+ import torchvision.transforms as T
13
+ import timm
14
+ import gradio as gr
15
+ from PIL import Image
16
+
17
device = torch.device("cpu")

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# All three models are pretrained on ImageNet-1k; inference runs on CPU.
model_registry = {
    "ResNet-50": models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1),
    "MobileNetV3-Small": models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1),
    "ViT-Base (timm)": timm.create_model("vit_base_patch16_224", pretrained=True),
}
for m in model_registry.values():
    m.eval().to(device)

# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Standard torchvision ImageNet eval pipeline (resize -> center crop -> normalize).
# NOTE(review): timm's vit_base_patch16_224 default data config may expect
# different normalization stats than torchvision's ImageNet mean/std — confirm;
# the shared pipeline still works but can shift the ViT's confidences slightly.
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Human-readable ImageNet class names, index-aligned with the 1000 model outputs.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
try:
    # Bounded timeout: without it, a dead network stalls app startup forever
    # and the except-fallback below never gets a chance to run.
    with urllib.request.urlopen(LABELS_URL, timeout=10) as resp:
        LABELS = json.loads(resp.read().decode())
except Exception:
    # Best-effort fallback: numeric class indices keep the UI usable offline.
    LABELS = [str(i) for i in range(1000)]
47
+
48
+
49
+ # ---------------------------------------------------------------------------
50
+ # Classify
51
+ # ---------------------------------------------------------------------------
52
def classify(image: Image.Image, model_name: str):
    """Run one registered model on *image* and return its top-5 predictions.

    The result is a dict mapping human-readable label -> probability, the
    shape ``gr.Label`` expects. An empty dict is returned when no image was
    supplied (e.g. the user clicked Classify with the input cleared).
    """
    if image is None:
        return {}

    # Force 3-channel RGB, then build a 1-image batch on the target device.
    batch = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    selected = model_registry[model_name]
    with torch.no_grad():
        scores = selected(batch)

    # Softmax over the class axis, then keep the five most likely classes.
    distribution = F.softmax(scores, dim=1)[0]
    top_probs, top_indices = torch.topk(distribution, 5)

    result = {}
    for prob, idx in zip(top_probs, top_indices):
        result[LABELS[idx]] = float(prob)
    return result
65
+
66
+
67
def compare_all(image: Image.Image):
    """Classify *image* with every model and return one result per model.

    Returns three top-5 prediction dicts in the fixed order
    (ResNet-50, MobileNetV3-Small, ViT-Base) that the UI outputs expect;
    three empty dicts when no image was supplied.
    """
    if image is None:
        return {}, {}, {}
    model_names = ("ResNet-50", "MobileNetV3-Small", "ViT-Base (timm)")
    return tuple(classify(image, name) for name in model_names)
75
+
76
+
77
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Two-tab Gradio app: "Single Model" runs one user-selected architecture,
# "Compare All Models" runs all three side by side on the same image.
with gr.Blocks(title="Image Classification") as demo:
    gr.Markdown(
        "# Image Classification\n"
        "Upload an image to compare predictions from different architectures.\n"
        "*Course: 100 Deep Learning ch2 — CNN*"
    )

    with gr.Tab("Single Model"):
        with gr.Row():
            with gr.Column():
                img_single = gr.Image(type="pil", label="Upload Image")
                # Dropdown choices come straight from the registry keys, so the
                # UI stays in sync if models are added or removed above.
                model_choice = gr.Dropdown(
                    list(model_registry.keys()), value="ResNet-50", label="Model"
                )
                btn_single = gr.Button("Classify", variant="primary")
            with gr.Column():
                # gr.Label renders the label->probability dict from classify().
                out_single = gr.Label(num_top_classes=5, label="Top-5 Predictions")

        btn_single.click(classify, [img_single, model_choice], out_single)

    with gr.Tab("Compare All Models"):
        with gr.Row():
            img_compare = gr.Image(type="pil", label="Upload Image")
            btn_compare = gr.Button("Compare All", variant="primary")
        with gr.Row():
            out_resnet = gr.Label(num_top_classes=5, label="ResNet-50")
            out_mobile = gr.Label(num_top_classes=5, label="MobileNetV3-Small")
            out_vit = gr.Label(num_top_classes=5, label="ViT-Base")

        # compare_all returns three dicts in the same order as these outputs.
        btn_compare.click(compare_all, [img_compare], [out_resnet, out_mobile, out_vit])

    # Example images wired only to the "Single Model" tab's input.
    gr.Examples(
        examples=[
            ["examples/cat.jpg"],
            ["examples/dog.jpg"],
            ["examples/car.jpg"],
        ],
        inputs=[img_single],
    )

if __name__ == "__main__":
    demo.launch()
examples/car.jpg ADDED
examples/cat.jpg ADDED
examples/dog.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=5.0.0
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ timm>=0.9.0
5
+ Pillow