File size: 2,207 Bytes
845f220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
from torchvision.models import EfficientNet_B2_Weights
import gradio as gr
from model import EffnetB2
from timeit import default_timer as timer
from typing import Tuple
import os


device = "cpu"

# Loading our model: build the architecture, then restore the trained weights into it.
model = EffnetB2()
# The checkpoint is used as a state_dict (see load_state_dict below), so restrict
# unpickling to tensors/plain containers; this avoids executing arbitrary pickled
# code from the file.
checkpoint = torch.load(
    "effnetB2_2025-11-29_epoch4.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint)
model.to(device)
# The app only ever runs inference, so switch to eval behaviour once at load time.
model.eval()

# Output index -> label; order must match the classifier head used in training.
class_names = ["pizza", "steak", "sushi"]


def predict(img) -> Tuple[dict, float]:
    """Classify an image into one of ``class_names`` and time the prediction.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input component,
            or ``None`` when the user submits without uploading an image.

    Returns:
        Tuple[dict, float]: a dict mapping each class name to its softmax
        probability, and the elapsed time in seconds (rounded to 4 decimals)
        spent on the forward pass and building the result dict.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message.
    """
    if img is None:
        # Without this guard the transform below fails with an opaque traceback.
        raise gr.Error("Please upload an image to classify.")

    # Apply the preprocessing bundled with the pretrained EfficientNet-B2 weights
    # and add a leading batch dimension of size 1.
    transform = EfficientNet_B2_Weights.DEFAULT.transforms()
    transformed_img = transform(img).unsqueeze(0).to(device)

    # Timing deliberately starts after preprocessing.
    start = timer()
    model.eval()  # ensure inference behaviour even if the model was switched elsewhere

    with torch.inference_mode():
        logits = model(transformed_img)
        pred_probs = torch.softmax(logits, dim=1)

    # Pair each class name with its probability for the single image in the batch.
    pred_dict = dict(zip(class_names, pred_probs.squeeze(0).cpu().tolist()))
    end = timer()
    pred_time = round(end - start, 4)
    return pred_dict, pred_time


# Creating a list of example images for the Gradio demo: one single-input row per
# image file in ./examples. os.listdir order is arbitrary, so sort for a stable order.
_EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif", ".tif", ".tiff")
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    # Skip stray non-image entries (e.g. .DS_Store) that the image input cannot load.
    if example.lower().endswith(_EXAMPLE_EXTENSIONS)
]

# Text passed to gr.Interface as its title, description and article.
title = "Food Vision Mini 🍕🥩🍣"
description = "An EfficientNetB2 feature extractor computer vision model to classify images as pizza, steak and sushi."
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

def _build_demo() -> gr.Interface:
    """Assemble the Gradio interface that wraps ``predict``.

    Components are created in the order image input, label output, number output.
    The interface uses the module-level examples and page text.
    """
    image_input = gr.Image(type="pil")
    label_output = gr.Label(num_top_classes=3, label="Predictions")
    time_output = gr.Number(label="Prediction Time(s)")
    return gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, time_output],
        examples=example_list,
        title=title,
        description=description,
        article=article,
    )


demo = _build_demo()

demo.launch(debug=False, share=True)