Spaces:
Runtime error
Runtime error
| # Import and class names setup | |
| import gradio as gr | |
| import os | |
| import torch | |
| from model import create_effnetb2_model | |
| from timeit import default_timer as timer | |
| from typing import Tuple, Dict | |
# Setup class names: one label per line, in the same order as the model's
# output logits (predict() pairs class_names[i] with logit i).
# Blank lines (e.g. an extra empty line at the end of the file) are skipped so
# they cannot add a spurious '' class and push len(class_names) past the
# number of model outputs.
with open('class_names.txt', 'r', encoding='utf-8') as f:
    class_names = [name for name in (line.strip() for line in f) if name]
# Model and transforms preparation.
# The classifier head must have exactly one output per entry in class_names:
# if its size differs from the checkpoint's, load_state_dict() below raises a
# size-mismatch RuntimeError and the app fails at startup.
# NOTE(review): assumes model.create_effnetb2_model accepts a `num_classes`
# keyword (its default may not be 101) — confirm against model.py.
effnetb2_model, effnetb2_transform = create_effnetb2_model(
    num_classes=len(class_names),
)

# Load the trained weights, mapping every tensor onto the CPU.
# torch.load unpickles the file, so only load checkpoints you trust; on
# torch >= 1.13 consider weights_only=True since this is a plain state dict.
effnetb2_model.load_state_dict(
    torch.load(
        f='effnetb2_food101_model.pth',
        map_location=torch.device('cpu'),
    )
)
# Predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify one image and time how long the prediction took.

    Args:
        img: The image coming from the ``gr.Image(type='pil')`` input, i.e. a
            PIL image. It is preprocessed with ``effnetb2_transform``.

    Returns:
        A tuple ``(pred_label_and_prob, pred_time)``:
        ``pred_label_and_prob`` maps every class name to its softmax
        probability, and ``pred_time`` is the elapsed time in seconds,
        rounded to 4 decimal places.

    Raises:
        ValueError: If the model produces a different number of scores than
            there are entries in ``class_names`` (labels would be misaligned).
    """
    start_time = timer()

    # Add a leading batch dimension: the model takes a batch, we have one image.
    transformed_image = effnetb2_transform(img).unsqueeze(0)

    # eval() switches layers such as dropout/batch-norm to inference behaviour;
    # inference_mode() disables autograd tracking for the forward pass.
    effnetb2_model.eval()
    with torch.inference_mode():
        pred_logits = effnetb2_model(transformed_image)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # One tensor -> Python-list conversion instead of indexing the tensor once
    # per class.
    probs = pred_probs[0].tolist()
    if len(probs) != len(class_names):
        raise ValueError(
            f'Model returned {len(probs)} class scores but class_names has '
            f'{len(class_names)} entries'
        )
    pred_label_and_prob = dict(zip(class_names, probs))

    pred_time = round(timer() - start_time, 4)
    return pred_label_and_prob, pred_time
# Create example list: one [image_path] row per example file, the row format
# Gradio's `examples` argument uses for a single-input interface.
# Entries are sorted because os.listdir() returns them in arbitrary order.
# Hidden entries (e.g. .DS_Store, .gitkeep) and sub-directories are skipped,
# since the gr.Image input could not load them as example images.
example_dir = 'example'
example_list = [
    [os.path.join(example_dir, example)]
    for example in sorted(os.listdir(example_dir))
    if not example.startswith('.')
    and os.path.isfile(os.path.join(example_dir, example))
]
# Create the Gradio app: a single PIL image goes in; the five most likely
# classes (with probabilities) and the prediction time come out.
title = 'FoodVision Large 🍕🥩🍣 '
description = 'An EfficientnetB2 feature extractor Computer vision model to classify 101 classes of food from the food 101 image dataset'
article = 'Created at [To be uploaded].'

image_input = gr.Image(type='pil')
prediction_outputs = [
    gr.Label(num_top_classes=5, label='predictions'),
    gr.Number(label='Prediction time (S)'),
]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Start the web server for the demo.
demo.launch()