File size: 2,839 Bytes
acc38f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c7199b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acc38f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
import torch
from torchvision import transforms, models
from PIL import Image
from torch import nn

model_name = "b0"
if model_name == "b4":
    IMAGE_RESIZE_SHAPE = 384
    IMAGE_FINAL_SHAPE = 380
    BATCH_SIZE = 32
    FEATURE_SHAPE = 1792

if model_name == "b0":
    IMAGE_RESIZE_SHAPE = 256
    IMAGE_FINAL_SHAPE = 224
    BATCH_SIZE = 32
    FEATURE_SHAPE = 1280


def load_labels(label_text_path):
    """Read class labels (one per line) and map class index -> label name.

    Args:
        label_text_path: Path to a text file containing one label per line.

    Returns:
        dict[int, str]: Mapping from 0-based class index to stripped label.
    """
    # Explicit encoding avoids platform-dependent defaults.
    with open(label_text_path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
    return dict(enumerate(labels))


# Index -> class-name mapping for the Food101 categories.
# NOTE(review): assumes labels.txt sits in the working directory — confirm.
label_dict = load_labels("labels.txt")

# Load PyTorch model
# NOTE(review): checkpoint appears to be a state_dict (it is passed to
# load_state_dict below), force-mapped to CPU for serving — confirm.
model_params = torch.load("food101.pt", map_location=torch.device("cpu"))
if model_name == "b4":
    model = models.efficientnet_b4()
if model_name == "b0":
    model = models.efficientnet_b0()

model.eval()  # inference mode: disables dropout, uses running BN statistics
# Freeze all weights — this app only performs inference, never training.
for params in model.parameters():
    params.requires_grad = False
# Replace the stock 1000-class head with a 101-class head *before* loading
# the fine-tuned weights, so the state_dict tensor shapes match.
model.classifier[1] = nn.Linear(in_features=FEATURE_SHAPE, out_features=101)
model.load_state_dict(model_params)

# Define image transformation
# ImageNet channel statistics — inputs must be standardized the same way
# the backbone was trained.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing pipeline: resize the short side, crop the center square,
# convert to a CHW float tensor in [0, 1], then standardize per channel.
transform = transforms.Compose([
    transforms.Resize(IMAGE_RESIZE_SHAPE),
    transforms.CenterCrop(IMAGE_FINAL_SHAPE),
    transforms.ToTensor(),
    normalize,
])


# Define prediction function
# Define prediction function
def predict_image_class(image):
    """Classify a food image and return the 10 most likely classes.

    Args:
        image: H x W x 3 numpy array (uint8-compatible), as supplied by
            the Gradio ``gr.Image`` input component.

    Returns:
        dict[str, float]: Label -> softmax probability for the top 10
        classes, in descending order of probability.
    """
    # Gradio hands us a numpy array; torchvision transforms expect PIL.
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")

    # Preprocess and add the batch dimension the model expects: (1, C, H, W).
    batch = transform(pil_image).unsqueeze(0)

    # Inference only — no gradient bookkeeping needed anywhere below.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # topk avoids sorting all 101 scores just to keep the best 10; softmax
    # is monotonic, so the ordering matches sorting the raw logits.
    top_probs, top_indices = torch.topk(probabilities, k=10)
    return {
        label_dict[idx.item()]: prob.item()
        for idx, prob in zip(top_indices, top_probs)
    }

def main():
    """Build and launch the Gradio demo UI for the Food101 classifier."""
    # Implicit string concatenation: the original backslash continuation
    # baked the second line's indentation into the user-visible text.
    description = (
        "This is a demo of EfficientNet trained on Food101 dataset. "
        "Upload an image of food and it will predict the class of the food."
    )
    inputs = gr.Image()
    outputs = gr.Label(num_top_classes=10, label="Prediction")
    gradio_app = gr.Interface(
        fn=predict_image_class,
        inputs=inputs,
        outputs=outputs,
        title="FoodVision",
        description=description,
        theme="snehilsanyal/scikit-learn",
        examples=[
            ["examples/pizza.jpg"],
            ["examples/samosa.jpg"],
        ],
    )

    # 0.0.0.0 exposes the app outside the container on the standard
    # Gradio/Spaces port; queue() serializes concurrent predictions.
    gradio_app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    # Entry point: launch the Gradio app only when run as a script,
    # not when this module is imported (e.g. by tests).
    main()