Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import torch | |
| from model import create_effnet_b2 | |
| from timeit import default_timer as timer | |
| from typing import Tuple, Dict | |
# Class names. predict() pairs class_names[i] with output logit i, so this
# order must match the order the model was trained with.
class_names = ['pizza', 'steak', 'sushi']

# Model and transforms preparation. The classifier head size is derived from
# ``class_names`` so the two cannot silently drift apart.
effnetb2, effnetb2_transforms = create_effnet_b2(num_classes=len(class_names))

# Load the saved weights. map_location forces every tensor onto the CPU,
# regardless of the device the checkpoint was saved from.
effnetb2.load_state_dict(
    torch.load(
        f='pretrained_effnetb2_feature_extractor.pth',
        map_location=torch.device('cpu'),
    )
)
# Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a single image with the EffNetB2 model.

    Args:
        img: An input image accepted by ``effnetb2_transforms`` (the Gradio
            interface in this file passes a PIL image).

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``:
        ``pred_labels_and_probs`` maps each name in ``class_names`` to its
        softmax probability, and ``pred_time`` is the wall-clock time in
        seconds spent transforming the image and running the model.
    """
    start_time = timer()

    # Transform the image and add a leading batch dimension of size 1.
    transformed_image = effnetb2_transforms(img).unsqueeze(0)

    # Put the model into eval mode and predict without tracking gradients.
    effnetb2.eval()
    with torch.inference_mode():
        pred_logits = effnetb2(transformed_image)
        pred_probs = torch.softmax(pred_logits, dim=1)

    # Map each class name to the probability at the same index of the
    # (single) batch row. Indexing (rather than zip) keeps a loud IndexError
    # if the model emits fewer logits than there are class names.
    pred_labels_and_probs = {
        name: float(pred_probs[0][i]) for i, name in enumerate(class_names)
    }

    pred_time = timer() - start_time
    return pred_labels_and_probs, pred_time
# Gradio app
# (gradio is already imported at the top of the file.)

# Create title and description
title = 'FoodVision Mini'
description = 'An EfficientNetB2 feature extractor to classify food as pizza, steak, and sushi'

# Create example list: one single-input example per file in ``examples/``.
# Sorted because os.listdir returns entries in arbitrary order.
example_list = [
    [os.path.join('examples', example)]
    for example in sorted(os.listdir('examples'))
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        # Show every class; derived from class_names so it stays in sync.
        gr.Label(num_top_classes=len(class_names), label='Predictions'),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list,
    title=title,
    description=description,
)

# NOTE(review): share=True only creates a public link for local runs;
# Gradio does not support it on Hugging Face Spaces (ignored with a warning).
demo.launch(debug=False, share=True)