### 1. Imports and class names setup ###
import os
from pathlib import Path
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import torch
from torchvision import transforms

# Output labels, in the same order as the model's output logits.
class_names = ["meme", "non-meme"]

### 2. Model and transforms preparation ###
model_path = Path("efficientNet_clf.pt")
model = torch.jit.load(model_path, map_location=torch.device("cpu"))
# The model is only used for CPU inference, so switch it to eval mode once at
# load time rather than on every request.
model.eval()

# Resize to a fixed 224x224 input, convert to a tensor and normalize.
# NOTE(review): mean/std are the common ImageNet channel statistics; presumably
# they match what the model was trained with -- confirm.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

print(image_transform)


### 3. Predict function ###
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify ``img`` as meme / non-meme and time the prediction.

    Args:
        img: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A tuple ``(label_to_prob, seconds)``. ``label_to_prob`` maps each entry
        of ``class_names`` to its softmax probability. ``seconds`` is the
        wall-clock time spent on preprocessing and inference, rounded to 5
        decimal places.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message in the UI).
    """
    if img is None:
        raise gr.Error("Please upload an image to classify.")

    start_time = timer()

    # Uploads may be RGBA (e.g. PNG with alpha) or grayscale. Normalize above
    # is configured for exactly 3 channels, so coerce to RGB first.
    img = img.convert("RGB")

    with torch.inference_mode():
        batch = image_transform(img).unsqueeze(dim=0)  # add a batch dimension
        pred_probs = torch.softmax(model(batch), dim=1)

    pred_labels_and_probs = {
        label: float(prob) for label, prob in zip(class_names, pred_probs[0])
    }
    pred_time = round(timer() - start_time, 5)
    return pred_labels_and_probs, pred_time


### 4. Gradio app ###
title = "Meme classification"
description = ("An EfficientNetB2 model to classify images into 2 classes: "
               "meme and non-meme")

# Sorted so the examples appear in a stable order. Hidden files such as
# .DS_Store are skipped because they are not images.
example_list = [
    "./example_imgs/" + name
    for name in sorted(os.listdir("./example_imgs"))
    if not name.startswith(".")
]

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()