Spaces:
Sleeping
Sleeping
File size: 1,890 Bytes
9e2f3a9 99a3a02 9e2f3a9 f472d67 e1139db efaee9b 9e2f3a9 1a2d2be 9e2f3a9 5b7ac00 9e2f3a9 c86dc42 9e2f3a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 | ### 1 Imports and class names setup###
import gradio as gr
import os
import torch
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import List, Dict,Tuple
# Order matters: predict() pairs class_names[i] with output probability i.
class_names = ["pizza", "steak", "sushi"]
### 2 model and transform preparation###
# Path of the trained EffNetB2 weights, relative to the working directory.
_CHECKPOINT_PATH = "11-model_deployment_effnetb2.pth"

# create_effnetb2_model yields a (model, transform) pair; the transform is what
# predict() applies to incoming images. One output unit per class name.
effnetb2_loaded, effnet_transform = create_effnetb2_model(
    num_classes=len(class_names),
)

# map_location="cpu" remaps any GPU-saved tensors so the checkpoint loads on a CPU-only host.
effnetb2_loaded.load_state_dict(
    torch.load(_CHECKPOINT_PATH, map_location="cpu")
)
effnetb2_loaded.to("cpu")
### 3 we need a predict function###
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The value produced by the ``gr.Image(type="pil")`` input component.
            It is presumably a PIL image, or ``None`` when no image was supplied.

    Returns:
        A tuple ``(pred_label_dict, inference_time)``.
        ``pred_label_dict`` maps each name in ``class_names`` to its softmax
        probability. ``inference_time`` is the elapsed time in seconds, as
        measured by ``timeit.default_timer``, rounded to 4 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Without this guard the transform fails with an opaque error. gr.Error
        # surfaces a readable message in the Gradio UI instead.
        raise gr.Error("Please upload an image before submitting.")

    # Start timing before preprocessing so the reported time covers the whole pipeline.
    start_time = timer()

    # Preprocess, then add a leading batch dimension of size 1.
    transformed_image = effnet_transform(img).unsqueeze(0)

    # eval() switches off training-only behaviour (e.g. dropout) before predicting.
    effnetb2_loaded.eval()
    with torch.inference_mode():
        logits = effnetb2_loaded(transformed_image)
        probs = torch.softmax(logits, dim=1)

    # Row 0 is the single image in the batch; pair each class name with its probability.
    pred_label_dict = dict(zip(class_names, probs[0].tolist()))

    inference_time = round(timer() - start_time, 4)
    return pred_label_dict, inference_time
### 4 Gradio app ###
title = "FoodVision mini models 🍕,🥩,🍣"
# Build the class list from class_names so the text stays in sync with the model.
description = (
    "An EfficientnetB2 feature extraction model is used to classify images as "
    + ", ".join(class_names)
)

# Each example is a list with one entry per input component (here: a single image path).
# os.listdir returns entries in arbitrary order, so sort for a stable example gallery.
example_list = [
    [os.path.join("examples", example)] for example in sorted(os.listdir("examples"))
]

# Create a Gradio demo: one PIL image in; class probabilities and prediction time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        # Show every class instead of a hard-coded count, so this tracks class_names.
        gr.Label(num_top_classes=len(class_names), label="prediction"),
        gr.Number(label="Prediction time in seconds"),
    ],
    examples=example_list,
    title=title,
    description=description,
    cache_examples=False,
)
demo.launch(share=False)
|