Sandhya commited on
Commit
1afa467
·
1 Parent(s): 9e44b33

first commit

Browse files
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from model import create_model
5
+ from timeit import default_timer as Timer
6
+ class_names=["Pizza","Steak","Sushi"]
7
+ effnetb2,effnetb2_transform=create_model(num_class=3)
8
+ effnetb2.load_state_dict(torch.load(f="demos\foodvision_mini\effnetb2_pizza_steak_sushi (1).pt",map_location=torch.device('cpu')))
9
+
10
+
11
+ def predict(img):
12
+ start_time=timer()
13
+ img=effnetb2_transform(img).unsqueeze(0)
14
+ effnetb2.eval()
15
+ with torch.inference_mode():
16
+ predict_logit=effnetb2(img)
17
+ predict_prob=torch.softmax(predict_logit,dim=1)
18
+ predict_label=torch.argmax(predict_prob,dim=1)
19
+ pred_label_prob={class_names[i]:float(predict_prob[0][i]) for i in range(len(class_names))}
20
+ end_time=timer()
21
+ pred_time=round(end_time-start_time,4)
22
+ return pred_label_prob,pred_time
23
+
24
+ title="Food Vision Mini 🍕🥩🍣"
25
+ description="An EfficientNetB2 feature extractor computer vision model to classify images of food as pizza, steak or sushi."
26
+ example_list=[['examples/'+example] for example in os.listdir("examples")]
27
+ demo=gr.Interface(fn=predict,inputs=gr.Image(type="pil"),outputs=[gr.Label(num_top_classes=3,label="Prediction"),gr.Number(label="Prediction time")],examples=example_list,title=title,description=description)
28
+ demo.launch()
effnetb2_pizza_steak_sushi (1).pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee17107d6aa14ae47b961ab7b51e0a0a9c9ba04b5e35bf2fa546cca1d15677b5
3
+ size 31284026
examples/289822.jpg ADDED
examples/2901001.jpg ADDED
examples/796922.jpg ADDED
model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torchvision import transforms
4
+ from torch import nn
5
+ def create_model(num_class:int=3,seed:int=42):
6
+ weights=torchvision.models.EfficientNet_B2_Weights.DEFAULT
7
+ transform=weights.transforms()
8
+ model=torchvision.models.efficientnet_b2(weights=weights)
9
+ for params in model.parameters():
10
+ params.requires_grad=False
11
+ torch.manual_seed(seed)
12
+ model.classifier=nn.Sequential(nn.Dropout(p=0.3, inplace=True),nn.Linear(in_features=1408, out_features=num_class, bias=True))
13
+ return model,transform
requirements.txt ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ aiobotocore==2.20.0
3
+ aiofiles==24.1.0
4
+ aiohappyeyeballs==2.4.6
5
+ aiohttp==3.11.13
6
+ aioitertools==0.12.0
7
+ aiosignal==1.3.2
8
+ altair==5.4.1
9
+ anndata==0.11.3
10
+ annotated-types==0.7.0
11
+ anyio==4.6.2.post1
12
+ argon2-cffi==23.1.0
13
+ argon2-cffi-bindings==21.2.0
14
+ array_api_compat==1.11
15
+ arrow==1.3.0
16
+ asciitree==0.3.3
17
+ asttokens==2.4.1
18
+ astunparse==1.6.3
19
+ async-lru==2.0.4
20
+ attrs==24.2.0
21
+ babel==2.16.0
22
+ backports.tarfile==1.2.0
23
+ beautifulsoup4==4.12.3
24
+ bleach==6.1.0
25
+ blinker==1.8.2
26
+ blis==1.2.0
27
+ botocore==1.36.23
28
+ cachetools==5.5.0
29
+ catalogue==2.0.10
30
+ certifi==2024.8.30
31
+ cffi==1.17.1
32
+ charset-normalizer==3.4.0
33
+ click==8.1.7
34
+ cloudpathlib==0.20.0
35
+ cloudpickle==3.1.1
36
+ colorama==0.4.6
37
+ colorcet==3.1.0
38
+ coloredlogs==15.0.1
39
+ comm==0.2.2
40
+ confection==0.1.5
41
+ contourpy==1.3.1
42
+ cycler==0.12.1
43
+ cymem==2.0.11
44
+ dask==2024.11.2
45
+ dask-expr==1.1.19
46
+ dask-image==2024.5.3
47
+ dataclasses-json==0.6.7
48
+ datashader==0.17.0
49
+ debugpy==1.8.7
50
+ decorator==5.1.1
51
+ defusedxml==0.7.1
52
+ Deprecated==1.2.18
53
+ distributed==2024.11.2
54
+ distro==1.9.0
55
+ dlib==19.24.6
56
+ docutils==0.21.2
57
+ en_core_web_md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl#sha256=5e6329fe3fecedb1d1a02c3ea2172ee0fede6cea6e4aefb6a02d832dba78a310
58
+ executing==2.1.0
59
+ face-recognition==1.3.0
60
+ face-recognition-models==0.3.0
61
+ fastapi==0.115.12
62
+ fasteners==0.19
63
+ fastjsonschema==2.20.0
64
+ ffmpy==0.5.0
65
+ filelock==3.18.0
66
+ Flask==3.0.3
67
+ Flask-Cors==5.0.0
68
+ flatbuffers==25.1.24
69
+ fonttools==4.55.3
70
+ fqdn==1.5.1
71
+ frozenlist==1.5.0
72
+ fsspec==2025.2.0
73
+ gast==0.6.0
74
+ geopandas==1.0.1
75
+ gitdb==4.0.11
76
+ GitPython==3.1.43
77
+ google-ai-generativelanguage==0.6.15
78
+ google-api-core==2.24.1
79
+ google-api-python-client==2.162.0
80
+ google-auth==2.38.0
81
+ google-auth-httplib2==0.2.0
82
+ google-generativeai==0.8.4
83
+ google-pasta==0.2.0
84
+ googleapis-common-protos==1.68.0
85
+ gradio==5.26.0
86
+ gradio_client==1.9.0
87
+ greenlet==3.1.1
88
+ groovy==0.1.2
89
+ grpcio==1.71.0rc2
90
+ grpcio-status==1.71.0rc2
91
+ h11==0.14.0
92
+ h5py==3.12.1
93
+ httpcore==1.0.6
94
+ httplib2==0.22.0
95
+ httpx==0.27.2
96
+ httpx-sse==0.4.0
97
+ huggingface-hub==0.30.1
98
+ humanfriendly==10.0
99
+ id==1.5.0
100
+ idna==3.10
101
+ imageio==2.37.0
102
+ imbalanced-learn==0.13.0
103
+ imblearn==0.0
104
+ importlib_metadata==8.6.1
105
+ ipykernel==6.29.5
106
+ ipython==8.28.0
107
+ ipywidgets==8.1.5
108
+ isoduration==20.11.0
109
+ itsdangerous==2.2.0
110
+ jaraco.classes==3.4.0
111
+ jaraco.context==6.0.1
112
+ jaraco.functools==4.1.0
113
+ jax==0.5.0
114
+ jaxlib==0.5.0
115
+ jedi==0.19.1
116
+ Jinja2==3.1.4
117
+ jiter==0.9.0
118
+ jmespath==1.0.1
119
+ joblib==1.4.2
120
+ json5==0.9.25
121
+ jsonpatch==1.33
122
+ jsonpointer==3.0.0
123
+ jsonschema==4.23.0
124
+ jsonschema-specifications==2024.10.1
125
+ jupyter==1.1.1
126
+ jupyter-console==6.6.3
127
+ jupyter-events==0.10.0
128
+ jupyter-lsp==2.2.5
129
+ jupyter_client==8.6.3
130
+ jupyter_core==5.7.2
131
+ jupyter_server==2.14.2
132
+ jupyter_server_terminals==0.5.3
133
+ jupyterlab==4.2.5
134
+ jupyterlab_pygments==0.3.0
135
+ jupyterlab_server==2.27.3
136
+ jupyterlab_widgets==3.0.13
137
+ keras==3.8.0
138
+ keyring==25.6.0
139
+ kiwisolver==1.4.8
140
+ langchain==0.3.23
141
+ langchain-community==0.3.21
142
+ langchain-core==0.3.51
143
+ langchain-openai==0.3.12
144
+ langchain-text-splitters==0.3.8
145
+ langcodes==3.5.0
146
+ langgraph==0.3.28
147
+ langgraph-checkpoint==2.0.24
148
+ langgraph-prebuilt==0.1.8
149
+ langgraph-sdk==0.1.61
150
+ langsmith==0.3.30
151
+ language_data==1.3.0
152
+ lazy_loader==0.4
153
+ libclang==18.1.1
154
+ lightning-utilities==0.14.2
155
+ llvmlite==0.44.0
156
+ locket==1.0.0
157
+ marisa-trie==1.2.1
158
+ Markdown==3.7
159
+ markdown-it-py==3.0.0
160
+ MarkupSafe==3.0.1
161
+ marshmallow==3.26.1
162
+ matplot==0.1.9
163
+ matplotlib==3.10.0
164
+ matplotlib-inline==0.1.7
165
+ mdurl==0.1.2
166
+ mediapipe==0.10.21
167
+ mistune==3.0.2
168
+ ml-dtypes==0.4.1
169
+ monai==1.4.0
170
+ more-itertools==10.7.0
171
+ MouseInfo==0.1.3
172
+ mpmath==1.3.0
173
+ msgpack==1.1.0
174
+ multidict==6.1.0
175
+ multipledispatch==1.0.0
176
+ multiscale_spatial_image==2.0.2
177
+ murmurhash==1.0.12
178
+ mypy-extensions==1.0.0
179
+ mysql-connector-python==9.1.0
180
+ namex==0.0.8
181
+ narwhals==1.9.4
182
+ natsort==8.4.0
183
+ nbclient==0.10.0
184
+ nbconvert==7.16.4
185
+ nbformat==5.10.4
186
+ nest-asyncio==1.6.0
187
+ networkx==3.4.2
188
+ nh3==0.2.21
189
+ nltk==3.9.1
190
+ notebook==7.2.2
191
+ notebook_shim==0.2.4
192
+ numba==0.61.0
193
+ numcodecs==0.15.1
194
+ numpy==1.26.4
195
+ ome-zarr==0.10.3
196
+ onnx==1.17.0
197
+ onnxruntime==1.21.1
198
+ onnxscript==0.2.4
199
+ openai==1.72.0
200
+ opencv-contrib-python==4.11.0.86
201
+ opencv-python==4.11.0.86
202
+ opt_einsum==3.4.0
203
+ optree==0.14.0
204
+ orjson==3.10.16
205
+ ormsgpack==1.9.1
206
+ overrides==7.7.0
207
+ packaging==24.1
208
+ pandas==2.2.3
209
+ pandocfilters==1.5.1
210
+ param==2.2.0
211
+ parso==0.8.4
212
+ partd==1.4.2
213
+ pillow==10.4.0
214
+ PIMS==0.7
215
+ platformdirs==4.3.6
216
+ pooch==1.8.2
217
+ preshed==3.0.9
218
+ prometheus_client==0.21.0
219
+ prompt_toolkit==3.0.48
220
+ propcache==0.3.0
221
+ proto-plus==1.26.0
222
+ protobuf==5.29.3
223
+ psutil==6.1.0
224
+ pure_eval==0.2.3
225
+ pyarrow==17.0.0
226
+ pyasn1==0.6.1
227
+ pyasn1_modules==0.4.1
228
+ PyAutoGUI==0.9.54
229
+ pycparser==2.22
230
+ pyct==0.5.0
231
+ pydantic==2.10.6
232
+ pydantic-settings==2.8.1
233
+ pydantic_core==2.27.2
234
+ pydeck==0.9.1
235
+ pydub==0.25.1
236
+ PyGetWindow==0.0.9
237
+ Pygments==2.18.0
238
+ pyloco==0.0.139
239
+ PyMsgBox==1.0.9
240
+ pynput==1.7.7
241
+ pyogrio==0.10.0
242
+ pyopencl==2025.1
243
+ pyparsing==3.2.0
244
+ pypdf==5.4.0
245
+ pyperclip==1.9.0
246
+ pyproj==3.7.1
247
+ pyreadline3==3.5.4
248
+ PyRect==0.2.0
249
+ PyScreeze==1.0.1
250
+ pytesseract==0.3.13
251
+ python-dateutil==2.9.0.post0
252
+ python-dotenv==1.1.0
253
+ python-json-logger==2.0.7
254
+ python-multipart==0.0.20
255
+ pytools==2025.1.2
256
+ pytweening==1.2.0
257
+ pytz==2024.2
258
+ pywin32==308
259
+ pywin32-ctypes==0.2.3
260
+ pywinpty==2.0.14
261
+ PyYAML==6.0.2
262
+ pyzmq==26.2.0
263
+ readme_renderer==44.0
264
+ redis==5.1.1
265
+ referencing==0.35.1
266
+ regex==2024.9.11
267
+ requests==2.32.3
268
+ requests-toolbelt==1.0.0
269
+ rfc3339-validator==0.1.4
270
+ rfc3986==2.0.0
271
+ rfc3986-validator==0.1.1
272
+ rich==13.9.2
273
+ rpds-py==0.20.0
274
+ rsa==4.9
275
+ ruff==0.11.7
276
+ s3fs==2025.2.0
277
+ safehttpx==0.1.6
278
+ safetensors==0.5.3
279
+ scikit-image==0.25.2
280
+ scikit-learn==1.5.2
281
+ scipy==1.14.1
282
+ seaborn==0.13.2
283
+ semantic-version==2.10.0
284
+ Send2Trash==1.8.3
285
+ sentencepiece==0.2.0
286
+ shapely==2.0.7
287
+ shellingham==1.5.4
288
+ SimpleWebSocketServer==0.1.2
289
+ six==1.16.0
290
+ sklearn-compat==0.1.3
291
+ slicerator==1.1.0
292
+ smart-open==7.1.0
293
+ smmap==5.0.1
294
+ sniffio==1.3.1
295
+ some_library==0.0.1
296
+ sortedcontainers==2.4.0
297
+ sounddevice==0.5.1
298
+ soupsieve==2.6
299
+ spacy==3.8.4
300
+ spacy-legacy==3.0.12
301
+ spacy-loggers==1.0.5
302
+ spatial_image==1.2.1
303
+ spatialdata==0.3.0
304
+ spotipy==2.24.0
305
+ SQLAlchemy==2.0.40
306
+ srsly==2.5.1
307
+ stack-data==0.6.3
308
+ starlette==0.46.2
309
+ streamlit==1.39.0
310
+ sympy==1.13.1
311
+ tblib==3.0.0
312
+ tenacity==9.0.0
313
+ tensorboard==2.18.0
314
+ tensorboard-data-server==0.7.2
315
+ tensorflow==2.18.0
316
+ tensorflow-io-gcs-filesystem==0.31.0
317
+ tensorflow_intel==2.18.0
318
+ termcolor==2.5.0
319
+ terminado==0.18.1
320
+ thinc==8.3.4
321
+ threadpoolctl==3.5.0
322
+ tifffile==2025.2.18
323
+ tiktoken==0.9.0
324
+ tinycss2==1.3.0
325
+ tokenizers==0.21.1
326
+ toml==0.10.2
327
+ tomlkit==0.13.2
328
+ toolz==1.0.0
329
+ torch==2.6.0
330
+ torchinfo==1.8.0
331
+ torchmetrics==1.7.0
332
+ torchvision==0.21.0
333
+ tornado==6.4.1
334
+ tqdm==4.66.5
335
+ traitlets==5.14.3
336
+ transformers==4.50.3
337
+ twine==6.1.0
338
+ typer==0.15.2
339
+ types-python-dateutil==2.9.0.20241003
340
+ typing==3.7.4.3
341
+ typing-inspect==0.9.0
342
+ typing_extensions==4.12.2
343
+ tzdata==2024.2
344
+ uri-template==1.3.0
345
+ uritemplate==4.1.1
346
+ urllib3==2.2.3
347
+ ushlex==0.99.1
348
+ uvicorn==0.34.2
349
+ wasabi==1.1.3
350
+ watchdog==5.0.3
351
+ wcwidth==0.2.13
352
+ weasel==0.4.1
353
+ webcolors==24.8.0
354
+ webencodings==0.5.1
355
+ websocket-client==1.8.0
356
+ websockets==15.0.1
357
+ Werkzeug==3.1.0
358
+ widgetsnbextension==4.0.13
359
+ wrapt==1.17.2
360
+ xarray==2024.11.0
361
+ xarray-dataclasses==1.9.1
362
+ xarray-schema==0.0.3
363
+ xarray-spatial==0.4.0
364
+ xxhash==3.5.0
365
+ yarl==1.18.3
366
+ zarr==2.18.4
367
+ zict==3.0.0
368
+ zipp==3.21.0
369
+ zstandard==0.23.0