Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.models import EfficientNet_B2_Weights | |
| import gradio as gr | |
| from model import EffnetB2 | |
| from timeit import default_timer as timer | |
| from typing import Tuple | |
| import os | |
device = "cpu"

# Loading our model: build the EffnetB2 architecture and restore the
# fine-tuned weights from a local checkpoint (a state_dict, as the
# load_state_dict call below requires).
model = EffnetB2()
checkpoint = torch.load(
    "effnetB2_2025-11-29_epoch4.pt",
    map_location=device,  # load tensors onto the CPU regardless of where they were saved
    # Only tensors/primitive containers may be unpickled, so a tampered
    # checkpoint file cannot execute arbitrary code at load time.
    weights_only=True,
)
model.load_state_dict(checkpoint)
model.to(device)

# Labels in the order of the classifier's output logits.
# NOTE(review): presumably matches the label order used in training — confirm.
class_names = ["pizza", "steak", "sushi"]
def predict(img) -> Tuple[dict, float]:
    """Classify an image as one of ``class_names`` and time the inference.

    Args:
        img: The input image as delivered by ``gr.Image(type="pil")`` — a PIL
            image, or ``None`` when the user submits without uploading one.

    Returns:
        Tuple[dict, float]: a dict mapping each class name to its softmax
        probability, and the inference time in seconds rounded to 4 decimals.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image to classify.")

    # The EfficientNet-B2 preprocessing normalizes exactly 3 channels, so RGBA
    # (e.g. PNG with alpha) or grayscale images must be converted to RGB first.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    # Resize/crop/normalize exactly as the pretrained weights expect, then
    # add a batch dimension of 1.
    transform = EfficientNet_B2_Weights.DEFAULT.transforms()
    transformed_img = transform(img).unsqueeze(0).to(device)

    # Time only the forward pass and post-processing, not preprocessing.
    start = timer()
    model.eval()
    with torch.inference_mode():
        logits = model(transformed_img)
        pred_probs = torch.softmax(logits, dim=1)

    # Map each class name to its predicted probability (single-image batch).
    pred_dict = dict(zip(class_names, pred_probs.squeeze(0).cpu().tolist()))
    pred_time = round(timer() - start, 4)
    return pred_dict, pred_time
# Example images for the Gradio demo. Each entry is a one-element list because
# the interface has a single input component. Entries are sorted so the
# examples appear in a stable order (os.listdir order is arbitrary), and hidden
# files or subdirectories (e.g. .DS_Store) are skipped so only real files are
# offered as inputs.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
    and os.path.isfile(os.path.join("examples", example))
]
# Text shown on the Gradio page.
title = "Food Vision Mini 🍕🥩🍣"
description = "An EfficientNetB2 feature extractor computer vision model to classify images as pizza, steak and sushi."
article = "Created at [09. PyTorch Model Deployment](https://www.learnpytorch.io/09_pytorch_model_deployment/)."

# Wire predict() to the UI: one PIL image in; class probabilities and the
# inference time out (matching predict's (dict, float) return value).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class rather than a hard-coded count.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction Time(s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=False, share=True)