Spaces:
Sleeping
Sleeping
File size: 2,207 Bytes
845f220 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import torch
from torchvision.models import EfficientNet_B2_Weights
import gradio as gr
from model import EffnetB2
from timeit import default_timer as timer
from typing import Tuple
import os
device = "cpu"

# Loading our model: build the architecture, then restore the fine-tuned weights.
model = EffnetB2()
# map_location remaps every stored tensor onto `device`, so a checkpoint saved on a GPU
# still loads on this CPU-only host. The loaded state dict is handed straight to
# load_state_dict rather than kept in a module-level variable: load_state_dict copies
# the values into the model's own parameters, so keeping the dict around would hold a
# redundant second copy of every weight for the life of the process.
model.load_state_dict(
    torch.load(
        "effnetB2_2025-11-29_epoch4.pt",
        map_location=device,
    )
)
model.to(device)
# The app only ever runs inference, so switch to eval mode once at load time.
model.eval()

# Output labels; presumably in the same index order the model was trained with — the
# i-th logit is reported under class_names[i], so this order must match training.
class_names = ["pizza", "steak", "sushi"]
# Preprocessing pipeline bundled with the pretrained EfficientNet-B2 weights. It is built
# once here instead of being rebuilt on every request inside predict().
effnetb2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()


def predict(img) -> Tuple[dict, float]:
    """Classify an image and report the model's inference time.

    Args:
        img: The image from the Gradio ``Image(type="pil")`` input — presumably a
            ``PIL.Image.Image``; anything the EfficientNet-B2 transforms accept works.

    Returns:
        Tuple[dict, float]: A mapping from each name in ``class_names`` to its softmax
        probability, and the elapsed time in seconds (rounded to 4 decimal places).
        The timer starts after preprocessing, so the time covers the forward pass
        and building the result dict only.
    """
    # Apply the weights' preprocessing and add a leading batch dimension of size 1.
    transformed_img = effnetb2_transforms(img).unsqueeze(0).to(device)

    start = timer()
    model.eval()  # cheap safeguard in case something switched the model back to train mode
    with torch.inference_mode():
        logits = model(transformed_img)
        pred_probs = torch.softmax(logits, dim=1)

    # One probability per class, paired positionally with class_names.
    pred_dict = dict(zip(class_names, pred_probs.squeeze(0).cpu().tolist()))
    pred_time = round(timer() - start, 4)
    return pred_dict, pred_time
def list_example_images(directory="examples"):
    """Return Gradio example rows, one ``[path]`` row per file in *directory*.

    Hidden entries (names starting with ".", e.g. ``.DS_Store`` or ``.gitkeep``) and
    sub-directories are skipped so Gradio is never asked to load them as images.
    Results are sorted by file name because ``os.listdir`` order is arbitrary.
    If *directory* does not exist, an empty list is returned so the demo can still
    start without examples.
    """
    try:
        names = os.listdir(directory)
    except FileNotFoundError:
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(names)
        if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
    ]


# Creating a list of example images for the Gradio demo.
example_list = list_example_images()
title = "Food Vision Mini 🍕🥩🍣"
description = (
    "An EfficientNetB2 feature extractor computer vision model to classify "
    "images as pizza, steak or sushi."
)
article = (
    "Created following [09. PyTorch Model Deployment]"
    "(https://www.learnpytorch.io/09_pytorch_model_deployment/)."
)

# predict() returns (probability dict, seconds), which maps onto the Label + Number outputs.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class rather than a hard-coded count, so adding a class needs no edit here.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction Time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    # share=True requests a temporary public Gradio link when run locally
    # (share links are not supported on Hugging Face Spaces).
    demo.launch(debug=False, share=True)
|