# r-boudali's picture
# Add application files
# 845f220
import torch
from torchvision.models import EfficientNet_B2_Weights
import gradio as gr
from model import EffnetB2
from timeit import default_timer as timer
from typing import Tuple
import os
# All inference in this script runs on the CPU.
device = "cpu"

# Load the model: build the network, then restore its weights from the
# checkpoint file. The loaded object is handed straight to load_state_dict,
# so it is expected to be a plain state dict. weights_only=True restricts
# unpickling to tensors and basic containers, so a tampered checkpoint file
# cannot execute arbitrary code while it is being loaded.
model = EffnetB2()
checkpoint = torch.load(
    "effnetB2_2025-11-29_epoch4.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint)
model.to(device)
# The script only ever runs inference, so switch to eval mode once here.
model.eval()

# Labels shown to the user; presumably listed in the same order as the
# model's output logits -- confirm against the training code.
class_names = ["pizza", "steak", "sushi"]
def predict(img) -> Tuple[dict, float]:
    """Classify one image and report how long the forward pass took.

    Args:
        img: The image to classify. It is passed directly to the
            EfficientNet-B2 preprocessing transforms (presumably a PIL
            image, as delivered by the Gradio image input).

    Returns:
        Tuple[dict, float]: ``(pred_dict, pred_time)`` where ``pred_dict``
        maps every name in ``class_names`` to its softmax probability and
        ``pred_time`` is the elapsed time in seconds, rounded to 4 decimals.
        The timer starts after preprocessing, so only the forward pass and
        the post-processing are measured.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # The UI may invoke the function without an image; fail with a
        # readable message instead of an opaque error from the transforms.
        raise gr.Error("Please provide an image before submitting.")

    # Preprocess with the transforms bundled with the EfficientNet-B2 weights
    # and add a leading batch dimension of size 1.
    transform = EfficientNet_B2_Weights.DEFAULT.transforms()
    batch = transform(img).unsqueeze(0).to(device)

    start = timer()
    model.eval()  # idempotent; makes sure the model is in eval mode
    with torch.inference_mode():
        logits = model(batch)
        pred_probs = torch.softmax(logits, dim=1)

    # Pair each class name with its probability (batch dimension dropped).
    pred_dict = dict(zip(class_names, pred_probs.squeeze(0).cpu().tolist()))
    pred_time = round(timer() - start, 4)
    return pred_dict, pred_time
# Create the list of example images for the Gradio demo: one single-element
# row per file found in the "examples" directory.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# example order stable between runs; the isfile check skips sub-directories,
# which cannot be used as image examples.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if os.path.isfile(os.path.join("examples", example))
]
# User-facing text for the demo page (title, short description, and a
# Markdown footer link).
title = "Food Vision Mini 🍕🥩🍣"
description = "An efficientNetB2 feature extractor computer vision model to classify images as pizza, steak and sushi."
article = "Created at [09.Pytorch Model Deployment.](https://www.learnpytorch.io/09_pytorch_model_deployment/)"
# Assemble the Gradio UI: one image input (delivered to `predict` as a PIL
# image) and two outputs matching the tuple `predict` returns -- a label
# widget for the class probabilities and a number for the elapsed time.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=3, label="Predictions"),
        gr.Number(label="Prediction Time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `predict` or `demo`) does not launch
# a web server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=False, share=True)