amar6de2 commited on
Commit
a54b1d7
·
verified ·
1 Parent(s): 4c75799

Upload 5 files

Browse files
Files changed (6) hide show
  1. .gitattributes +1 -0
  2. app.py +76 -0
  3. class_names.txt +121 -0
  4. examples/chai.jpg +3 -0
  5. model.py +35 -0
  6. vit_epoch_2.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/chai.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch
import numpy as np
from PIL import Image
from model import create_vit_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# Load the food class labels, one label per line of class_names.txt.
with open("class_names.txt", "r") as class_file:
    class_names = [line.strip() for line in class_file]
### 2. Model and transforms preparation ###

# Build the ViT-B/16 feature extractor and its matching image transforms.
vit, vit_transforms = create_vit_model(num_classes=121)

# Restore the fine-tuned weights; map to CPU so the app also runs without a GPU.
state_dict = torch.load(f="vit_epoch_2.pth", map_location=torch.device("cpu"))
vit.load_state_dict(state_dict)
### 3. Predict function ###

def predict(img) -> Tuple[Dict, float]:
    """Classify a food image and report how long inference took.

    Args:
        img: Input image, either a PIL.Image or a numpy array (H, W[, C]).

    Returns:
        Tuple of (dict mapping each class name to its predicted probability,
        prediction time in seconds rounded to 5 decimals).
    """
    start_time = timer()

    # Ensure the image is in PIL format
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
    # Normalize to 3-channel RGB: grayscale or RGBA uploads would otherwise
    # crash the ViT transforms, which expect exactly 3 channels.
    # (No-op for images that are already RGB.)
    img = img.convert("RGB")

    # Transform and add batch dimension
    img = vit_transforms(img).unsqueeze(0)

    # Inference: eval mode + inference_mode disables dropout and autograd
    vit.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(vit(img), dim=1)

    # Map every class name to its predicted probability (Gradio Label input)
    pred_labels_and_probs = {
        class_names[i]: float(pred_probs[0][i]) for i in range(len(class_names))
    }

    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time
### 4. Gradio app ###

title = "VisionBite 🍔👁"
description = "A ViT feature extractor computer vision model to classify images of food into 121 categories."
article = "The model has been trained on the Food121 dataset using ViT Base 16."

# Collect the bundled example images, sorted so the UI order is stable.
_image_suffixes = (".jpg", ".png", ".jpeg")
example_list = [
    ["examples/" + name]
    for name in sorted(os.listdir("examples"))
    if name.endswith(_image_suffixes)
]

# Wire everything into a Gradio interface: image in, top-5 labels + latency out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

demo.launch()
class_names.txt ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ biryani
10
+ bread_pudding
11
+ breakfast_burrito
12
+ bruschetta
13
+ caesar_salad
14
+ cannoli
15
+ caprese_salad
16
+ carrot_cake
17
+ ceviche
18
+ chai
19
+ chapati
20
+ cheese_plate
21
+ cheesecake
22
+ chicken_curry
23
+ chicken_quesadilla
24
+ chicken_wings
25
+ chocolate_cake
26
+ chocolate_mousse
27
+ chole_bhature
28
+ churros
29
+ clam_chowder
30
+ club_sandwich
31
+ crab_cakes
32
+ creme_brulee
33
+ croque_madame
34
+ cup_cakes
35
+ dabeli
36
+ dal
37
+ deviled_eggs
38
+ dhokla
39
+ donuts
40
+ dosa
41
+ dumplings
42
+ edamame
43
+ eggs_benedict
44
+ escargots
45
+ falafel
46
+ filet_mignon
47
+ fish_and_chips
48
+ foie_gras
49
+ french_fries
50
+ french_onion_soup
51
+ french_toast
52
+ fried_calamari
53
+ fried_rice
54
+ frozen_yogurt
55
+ garlic_bread
56
+ gnocchi
57
+ greek_salad
58
+ grilled_cheese_sandwich
59
+ grilled_salmon
60
+ guacamole
61
+ gyoza
62
+ hamburger
63
+ hot_and_sour_soup
64
+ hot_dog
65
+ huevos_rancheros
66
+ hummus
67
+ ice_cream
68
+ idli
69
+ jalebi
70
+ kathi_rolls
71
+ kofta
72
+ kulfi
73
+ lasagna
74
+ lobster_bisque
75
+ lobster_roll_sandwich
76
+ macaroni_and_cheese
77
+ macarons
78
+ miso_soup
79
+ momos
80
+ mussels
81
+ naan
82
+ nachos
83
+ omelette
84
+ onion_rings
85
+ oysters
86
+ pad_thai
87
+ paella
88
+ pakoda
89
+ pancakes
90
+ pani_puri
91
+ panna_cotta
92
+ panner_butter_masala
93
+ pav_bhaji
94
+ peking_duck
95
+ pho
96
+ pizza
97
+ pork_chop
98
+ poutine
99
+ prime_rib
100
+ pulled_pork_sandwich
101
+ ramen
102
+ ravioli
103
+ red_velvet_cake
104
+ risotto
105
+ samosa
106
+ sashimi
107
+ scallops
108
+ seaweed_salad
109
+ shrimp_and_grits
110
+ spaghetti_bolognese
111
+ spaghetti_carbonara
112
+ spring_rolls
113
+ steak
114
+ strawberry_shortcake
115
+ sushi
116
+ tacos
117
+ takoyaki
118
+ tiramisu
119
+ tuna_tartare
120
+ vadapav
121
+ waffles
examples/chai.jpg ADDED

Git LFS Details

  • SHA256: a15ec79cfead22408d436d2f492e29e548e0cc848fa162c4fa2b4eb52c098534
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB
model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torchvision

from torch import nn


def create_vit_model(num_classes: int = 3,
                     seed: int = 42):
    """Creates a ViT-B/16 feature extractor model and transforms.

    Args:
        num_classes (int, optional): number of target classes. Defaults to 3.
        seed (int, optional): random seed value for output layer. Defaults to 42.

    Returns:
        model (torch.nn.Module): ViT-B/16 feature extractor model.
        transforms (torchvision.transforms): ViT-B/16 image transforms.
    """
    # Create ViT_B_16 pretrained weights, transforms and model
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze the backbone so only the new classifier head is trainable
    for param in model.parameters():
        param.requires_grad = False

    # Replace the classifier head; seed first for reproducible initialization
    torch.manual_seed(seed)
    model.heads = nn.Sequential(
        nn.LayerNorm(768),  # 768 = ViT-B/16 hidden dimension
        nn.Dropout(0.2),
        # BUG FIX: the head previously hard-coded 121 output units, silently
        # ignoring `num_classes`. Use the parameter so the function works for
        # any class count. (app.py calls with num_classes=121, so the deployed
        # behavior is unchanged.)
        nn.Linear(768, num_classes),
    )

    return model, transforms
vit_epoch_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e2feb708f66db4d26e017d955e6e8f8e64842e9a71e67e81cd3f4dc3f956eff
3
+ size 343634614