erdi28 commited on
Commit
61e3cb1
·
1 Parent(s): 68a0c01

Upload app_demo.ipynb

Browse files
Files changed (1) hide show
  1. app_demo.ipynb +243 -0
app_demo.ipynb ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "097f69c4",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import gradio\n",
11
+ "import torch\n",
12
+ "from torchvision import transforms\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import timm\n",
15
+ "import torch.nn as nn\n",
16
+ "from PIL import Image"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 5,
22
+ "id": "065d0efd",
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "idx_to_class = {0: 'adidas', 1: 'converse', 2: 'new-balance', 3: 'nike', 4: 'reebok', 5: 'vans'}\n",
27
+ "num_classes = len(idx_to_class)\n",
28
+ "\n",
29
+ "mean = [0.485, 0.456, 0.406]\n",
30
+ "std = [0.229, 0.224, 0.225]\n",
31
+ "test_transforms = transforms.Compose([transforms.Resize((224,224)),\n",
32
+ " transforms.ToTensor(),\n",
33
+ " transforms.Normalize(mean,std)\n",
34
+ " ])"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 6,
40
+ "id": "6a5105de",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def GetModel(model_name = 'efficientnet_b0',freeze = False):\n",
45
+ " model = timm.create_model(model_name = model_name,pretrained=True)\n",
46
+ " if freeze:\n",
47
+ " for parameter in model.parameters():\n",
48
+ " parameter.requires_grad = False\n",
49
+ " \n",
50
+ " in_features = model.classifier.in_features # 1792\n",
51
+ " \n",
52
+ " model.classifier = nn.Sequential(\n",
53
+ " nn.Linear(in_features, 100), \n",
54
+ " nn.BatchNorm1d(num_features=100),\n",
55
+ " nn.ReLU(),\n",
56
+ " nn.Dropout(),\n",
57
+ " nn.Linear(100, num_classes),\n",
58
+ " )\n",
59
+ " \n",
60
+ " return model"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": 7,
66
+ "id": "0b473e8c",
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "def LoadModel(model, model_path):\n",
71
+ " checkpoint = torch.load(model_path)\n",
72
+ " model.load_state_dict(checkpoint['state_dict'])\n",
73
+ " model.best_scores = checkpoint['best_stats']\n",
74
+ " return model"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 8,
80
+ "id": "40b3fd21",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "model = LoadModel(GetModel(),\"snicker_model.pth\")"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 36,
90
+ "id": "bc2bc9b7",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "def GetClassProbs(img):\n",
95
+ " with torch.no_grad():\n",
96
+ " model.eval()\n",
97
+ " #img = Image.open(img).convert(\"RGB\")\n",
98
+ " img = test_transforms(img)\n",
99
+ " img = img.unsqueeze(0)\n",
100
+ " output = model(img)\n",
101
+ " # remember softmax\n",
102
+ " probs = F.softmax(output,dim=1)\n",
103
+ " probs, indices = probs.topk(k=num_classes)\n",
104
+ " probs = probs[0].tolist()\n",
105
+ " indices = indices[0].tolist()\n",
106
+ " classes = [idx_to_class[index] for index in indices]\n",
107
+ " confidences = {classes[i]: round(probs[i],3) for i in range(num_classes)} \n",
108
+ "\n",
109
+ " return confidences"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 38,
115
+ "id": "e070f2a7",
116
+ "metadata": {},
117
+ "outputs": [
118
+ {
119
+ "name": "stdout",
120
+ "output_type": "stream",
121
+ "text": [
122
+ "Running on local URL: http://127.0.0.1:7862\n",
123
+ "Running on public URL: https://67885f1c-1326-46d9.gradio.live\n",
124
+ "\n",
125
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
126
+ ]
127
+ },
128
+ {
129
+ "data": {
130
+ "text/html": [
131
+ "<div><iframe src=\"https://67885f1c-1326-46d9.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
132
+ ],
133
+ "text/plain": [
134
+ "<IPython.core.display.HTML object>"
135
+ ]
136
+ },
137
+ "metadata": {},
138
+ "output_type": "display_data"
139
+ },
140
+ {
141
+ "data": {
142
+ "text/plain": []
143
+ },
144
+ "execution_count": 38,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "import gradio as gr\n",
151
+ "examples = [\"samples/a.jpeg\",\"samples/c.jpeg\",\"samples/r.jpeg\"]\n",
152
+ "gr.Interface(fn=GetClassProbs, \n",
153
+ " inputs=gr.Image(type=\"pil\"),\n",
154
+ " outputs=gr.Label(num_top_classes=3),\n",
155
+ " examples=examples).launch(share=True)\n"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "1f9eeb29",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": []
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "5fab64d9",
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": []
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "id": "cf1b0fb5",
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": []
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "f88d10d0",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": []
189
+ }
190
+ ],
191
+ "metadata": {
192
+ "kernelspec": {
193
+ "display_name": "Python 3 (ipykernel)",
194
+ "language": "python",
195
+ "name": "python3"
196
+ },
197
+ "language_info": {
198
+ "codemirror_mode": {
199
+ "name": "ipython",
200
+ "version": 3
201
+ },
202
+ "file_extension": ".py",
203
+ "mimetype": "text/x-python",
204
+ "name": "python",
205
+ "nbconvert_exporter": "python",
206
+ "pygments_lexer": "ipython3",
207
+ "version": "3.9.13"
208
+ },
209
+ "latex_envs": {
210
+ "LaTeX_envs_menu_present": true,
211
+ "autoclose": false,
212
+ "autocomplete": false,
213
+ "bibliofile": "biblio.bib",
214
+ "cite_by": "apalike",
215
+ "current_citInitial": 1,
216
+ "eqLabelWithNumbers": true,
217
+ "eqNumInitial": 1,
218
+ "hotkeys": {
219
+ "equation": "Ctrl-E",
220
+ "itemize": "Ctrl-I"
221
+ },
222
+ "labels_anchors": false,
223
+ "latex_user_defs": false,
224
+ "report_style_numbering": false,
225
+ "user_envs_cfg": false
226
+ },
227
+ "toc": {
228
+ "base_numbering": 1,
229
+ "nav_menu": {},
230
+ "number_sections": true,
231
+ "sideBar": true,
232
+ "skip_h1_title": false,
233
+ "title_cell": "Table of Contents",
234
+ "title_sidebar": "Contents",
235
+ "toc_cell": false,
236
+ "toc_position": {},
237
+ "toc_section_display": true,
238
+ "toc_window_display": false
239
+ }
240
+ },
241
+ "nbformat": 4,
242
+ "nbformat_minor": 5
243
+ }