hari31416 commited on
Commit
acc38f0
·
1 Parent(s): d71a61c

Added files

Browse files
Files changed (6) hide show
  1. app.py +105 -0
  2. examples/pizza.jpg +0 -0
  3. examples/samosa.jpg +0 -0
  4. food101.pt +3 -0
  5. labels.txt +101 -0
  6. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ from torch import nn
6
+
7
+ model_name = "b0"
8
+ if model_name == "b4":
9
+ IMAGE_RESIZE_SHAPE = 384
10
+ IMAGE_FINAL_SHAPE = 380
11
+ BATCH_SIZE = 32
12
+ FEATURE_SHAPE = 1792
13
+
14
+ if model_name == "b0":
15
+ IMAGE_RESIZE_SHAPE = 256
16
+ IMAGE_FINAL_SHAPE = 224
17
+ BATCH_SIZE = 32
18
+ FEATURE_SHAPE = 1280
19
+
20
+
21
+ def load_labels(label_text_path):
22
+ with open(label_text_path, "r") as f:
23
+ lables = [line.strip() for line in f.readlines()]
24
+ label_dict = {i: lables[i] for i in range(len(lables))}
25
+ return label_dict
26
+
27
+
28
+ label_dict = load_labels("labels.txt")
29
+
30
+ # Load PyTorch model
31
+ model_params = torch.load("food101.pt", map_location=torch.device("cpu"))
32
+ if model_name == "b4":
33
+ model = models.efficientnet_b4()
34
+ if model_name == "b0":
35
+ model = models.efficientnet_b0()
36
+
37
+ model.eval()
38
+ for params in model.parameters():
39
+ params.requires_grad = False
40
+ model.classifier[1] = nn.Linear(in_features=FEATURE_SHAPE, out_features=101)
41
+ model.load_state_dict(model_params)
42
+
43
+ # Define image transformation
44
+ normalize = transforms.Normalize(
45
+ mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225],
47
+ )
48
+
49
+ transform = transforms.Compose(
50
+ [
51
+ transforms.Resize(IMAGE_RESIZE_SHAPE),
52
+ transforms.CenterCrop(IMAGE_FINAL_SHAPE),
53
+ transforms.ToTensor(),
54
+ normalize,
55
+ ]
56
+ )
57
+
58
+
59
+ # Define prediction function
60
+ def predict_image_class(image):
61
+ # Load image
62
+ image = Image.fromarray(image.astype("uint8"), "RGB")
63
+
64
+ # Apply transformation
65
+ transformed_image = transform(image)
66
+
67
+ # Add batch dimension
68
+ transformed_image = transformed_image.unsqueeze(0)
69
+
70
+ # Disable gradient calculation
71
+ with torch.no_grad():
72
+ # Make prediction
73
+ output = model(transformed_image)
74
+ _, indices = torch.sort(output, descending=True)
75
+
76
+ percentage = torch.nn.functional.softmax(output, dim=1)[0]
77
+ # create a dictionary of top 10 classes
78
+ top_10 = {}
79
+ for idx in indices[0][:10]:
80
+ top_10[label_dict[idx.item()]] = percentage[idx].item()
81
+ return top_10
82
+
83
+
84
+ # Define Gradio interface
85
+ description = "This is a demo of EfficientNet trained on Food101 dataset.\
86
+ Upload an image of food and it will predict the class of the food."
87
+ inputs = gr.inputs.Image()
88
+ outputs = gr.outputs.Label(num_top_classes=10)
89
+ gradio_app = gr.Interface(
90
+ fn=predict_image_class,
91
+ inputs=inputs,
92
+ outputs=outputs,
93
+ title="FoodVision",
94
+ description=description,
95
+ theme="snehilsanyal/scikit-learn",
96
+ examples=[
97
+ ["examples/pizza.jpg"],
98
+ ["examples/samosa.jpg"],
99
+ ],
100
+ )
101
+
102
+ # Run Gradio app
103
+ gradio_app.launch(
104
+ server_port=7860,
105
+ )
examples/pizza.jpg ADDED
examples/samosa.jpg ADDED
food101.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5bda7bcad0d9284628d7d3a22e9d2179b594f8dcab2c1d2c1e3456779d571b6
3
+ size 16844913
labels.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apple pie
2
+ Baby back ribs
3
+ Baklava
4
+ Beef carpaccio
5
+ Beef tartare
6
+ Beet salad
7
+ Beignets
8
+ Bibimbap
9
+ Bread pudding
10
+ Breakfast burrito
11
+ Bruschetta
12
+ Caesar salad
13
+ Cannoli
14
+ Caprese salad
15
+ Carrot cake
16
+ Ceviche
17
+ Cheesecake
18
+ Cheese plate
19
+ Chicken curry
20
+ Chicken quesadilla
21
+ Chicken wings
22
+ Chocolate cake
23
+ Chocolate mousse
24
+ Churros
25
+ Clam chowder
26
+ Club sandwich
27
+ Crab cakes
28
+ Creme brulee
29
+ Croque madame
30
+ Cup cakes
31
+ Deviled eggs
32
+ Donuts
33
+ Dumplings
34
+ Edamame
35
+ Eggs benedict
36
+ Escargots
37
+ Falafel
38
+ Filet mignon
39
+ Fish and chips
40
+ Foie gras
41
+ French fries
42
+ French onion soup
43
+ French toast
44
+ Fried calamari
45
+ Fried rice
46
+ Frozen yogurt
47
+ Garlic bread
48
+ Gnocchi
49
+ Greek salad
50
+ Grilled cheese sandwich
51
+ Grilled salmon
52
+ Guacamole
53
+ Gyoza
54
+ Hamburger
55
+ Hot and sour soup
56
+ Hot dog
57
+ Huevos rancheros
58
+ Hummus
59
+ Ice cream
60
+ Lasagna
61
+ Lobster bisque
62
+ Lobster roll sandwich
63
+ Macaroni and cheese
64
+ Macarons
65
+ Miso soup
66
+ Mussels
67
+ Nachos
68
+ Omelette
69
+ Onion rings
70
+ Oysters
71
+ Pad thai
72
+ Paella
73
+ Pancakes
74
+ Panna cotta
75
+ Peking duck
76
+ Pho
77
+ Pizza
78
+ Pork chop
79
+ Poutine
80
+ Prime rib
81
+ Pulled pork sandwich
82
+ Ramen
83
+ Ravioli
84
+ Red velvet cake
85
+ Risotto
86
+ Samosa
87
+ Sashimi
88
+ Scallops
89
+ Seaweed salad
90
+ Shrimp and grits
91
+ Spaghetti bolognese
92
+ Spaghetti carbonara
93
+ Spring rolls
94
+ Steak
95
+ Strawberry shortcake
96
+ Sushi
97
+ Tacos
98
+ Takoyaki
99
+ Tiramisu
100
+ Tuna tartare
101
+ Waffles
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ pytorch
3
+ torchvision
4
+ pillow