FoodVision / app.py
hasorez's picture
first commit
43e60f0
raw
history blame contribute delete
1.28 kB
import torch
import os
import gradio as gr
from model import create_swin
from timeit import default_timer as timer
# Read the class labels, one per line. Blank lines (e.g. an extra empty line at
# the end of the file) are skipped so they cannot inflate the number of classes
# the model head is built with, which would break load_state_dict below.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = [name for name in (line.strip() for line in f) if name]

# create_swin returns the model sized for len(class_names) outputs together with
# the transform that predict() applies to each input image.
swin, swin_transforms = create_swin(len(class_names))

# Load the fine-tuned weights; map_location lets a checkpoint saved on GPU be
# loaded on a CPU-only host.
swin.load_state_dict(
    torch.load(
        f="pretrained_swin_food101_dataset.pth",
        map_location=torch.device("cpu"),  # load to CPU
    )
)
def predict(img):
    """Classify a food image and report how long inference took.

    Args:
        img: The image supplied by the ``gr.Image(type="pil")`` input, i.e. a
            PIL image; it is preprocessed with ``swin_transforms``.

    Returns:
        A tuple ``(label_to_prob, seconds)``:
        ``label_to_prob`` maps every class name to its softmax probability
        (the format ``gr.Label`` consumes), and ``seconds`` is the elapsed
        wall-clock time of the prediction, rounded to 2 decimal places.
    """
    start = timer()

    # Preprocess and add a leading dimension of size 1 (a batch of one image).
    batch = swin_transforms(img).unsqueeze(0)

    # Evaluation mode disables training-only behaviour (e.g. dropout).
    swin.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(swin(batch), dim=1)

    # Convert the single row of probabilities to Python floats in one call and
    # pair each probability with its class name (model outputs follow the order
    # of class_names, since the head was built from that list).
    probs = pred_probs[0].tolist()
    pred_labels_and_probs = dict(zip(class_names, probs))

    return pred_labels_and_probs, round(timer() - start, 2)
title = "FoodVision πŸ’»πŸ‘οΈ"
description = "A Swin Transformer feature extractor computer vision model for classifying images of food"
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Create the Gradio demo
demo = gr.Interface(fn=predict,
inputs=gr.Image(type="pil"),
outputs=[gr.Label(num_top_classes=5, label="Predictions"), gr.Number(label="Prediction time (s)")],
examples = example_list,
title=title,
description=description
)
demo.launch()