FoodVision-Big / app.py
mrunalmania's picture
app.py -> class name path misconfigured
37eb346 verified
### 1. Imports and class names setup ###
import gradio as gr
import torch
from PIL import Image
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import os
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict
# Read the class labels: one label per line, surrounding whitespace removed.
with open("class_names.txt", "r") as f:
    class_names = list(map(str.strip, f))

# Build the EfficientNet-B2 model together with its matching image transforms,
# sized so the model has one output per class label.
num_food_classes = len(class_names)
effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=num_food_classes)

# Restore the saved weights, mapping every tensor onto the CPU.
saved_state_dict = torch.load(
    f="09_pretrained_effnetb2_extractor_food101_20_percent.pth",
    map_location=torch.device("cpu"),
)
effnetb2.load_state_dict(saved_state_dict)
### 2. Predict function ###
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a food image and report how long the prediction took.

    Args:
        img: The image to classify, in a form accepted by
            ``effnetb2_transforms``. Gradio calls this function with
            ``None`` when no image was provided.

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dictionary mapping
        each entry of ``class_names`` to its softmax probability, and the
        elapsed time in seconds rounded to 4 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``, so the UI shows a readable message
            instead of an opaque failure from inside the transform pipeline.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    # Start the timer so preprocessing and inference are both measured.
    start_time = timer()

    # Transform the input image and add a batch dimension for the model.
    batch = effnetb2_transforms(img).unsqueeze(0)

    # Put the model into eval mode and run the forward pass without autograd.
    effnetb2.eval()
    with torch.inference_mode():
        logits = effnetb2(batch)
        # Softmax over the class dimension; row 0 is the only image in the
        # batch. One ``tolist()`` call converts every probability to a Python
        # float at once instead of indexing the tensor once per class.
        probs = torch.softmax(logits, dim=1)[0].tolist()

    # Pair each class name with its probability (the model is created with
    # one output per class name, so the two sequences line up).
    pred_labels_and_probs = dict(zip(class_names, probs))

    # Stop the timer only after the result dictionary has been built.
    pred_time = round(timer() - start_time, 4)
    return pred_labels_and_probs, pred_time
### 3. Gradio app ###
# Text shown on the demo page.
title = "Food Vision Big"
description = "An effientnetb2 model to classify food images of classes 101"
article = "Created at pytorch_model_deployment"

# Gradio expects a list of lists: one inner list per example, holding the
# path of a file found in the "examples" directory.
example_list = [
    ["examples/" + file_name]
    for file_name in os.listdir("examples")
]

# Input and output components of the interface.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5, label="Predictions")
time_output = gr.Number(label="Prediction time (s)")

# Assemble the demo around the predict function.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[label_output, time_output],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start serving the demo.
demo.launch(debug=True)