"""Gradio demo: classify an uploaded image with an ImageNet-pretrained ResNet-18."""

import functools

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms

# Pretrained ImageNet weights. ``meta["categories"]`` holds the class names and
# ``transforms()`` returns the preprocessing preset bundled with these weights
# (resize + 224 center crop + ImageNet mean/std normalization).
_WEIGHTS = torchvision.models.ResNet18_Weights.DEFAULT
_DEFAULT_TRANSFORM = _WEIGHTS.transforms()


@functools.lru_cache(maxsize=1)
def _default_model():
    """Build the pretrained ResNet-18 once and reuse it for every request."""
    model = torchvision.models.resnet18(weights=_WEIGHTS)
    model.eval()
    return model


def _to_rgb_image(img_path):
    """Convert a file path, numpy array, or PIL image into an RGB PIL image."""
    if isinstance(img_path, Image.Image):
        return img_path.convert("RGB")
    if isinstance(img_path, np.ndarray):
        return Image.fromarray(img_path).convert("RGB")
    # Close the file handle promptly; convert() returns a fully loaded copy.
    with Image.open(img_path) as opened:
        return opened.convert("RGB")


def predict(img_path, model=None, class_names=None, transform=None):
    """Classify an image and return a ``{class_name: probability}`` mapping.

    Args:
        img_path: Image to classify: a file path, a numpy array (what
            ``gr.Image()`` passes by default), or a PIL image.
        model: ``nn.Module`` classifier to use. Defaults to the cached
            ImageNet-pretrained ResNet-18. ``model.eval()`` is called on it.
        class_names: Labels for the model's output logits, in order. Defaults
            to the ImageNet categories when the output width matches them,
            otherwise to the stringified class indices.
        transform: Callable mapping an RGB PIL image to a ``(C, H, W)`` tensor.
            Defaults to the preprocessing preset shipped with the ResNet-18
            weights, which is what those weights were evaluated with.

    Returns:
        dict mapping every class name to its softmax probability (float).

    Raises:
        gr.Error: if no image was provided (Gradio passes ``None`` then).
        ValueError: if ``class_names`` does not match the number of logits.
    """
    if img_path is None:
        raise gr.Error("Please upload an image to classify.")
    if model is None:
        model = _default_model()
    if transform is None:
        transform = _DEFAULT_TRANSFORM

    batch = transform(_to_rgb_image(img_path)).unsqueeze(0)
    # Send the input to whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    # Index the single batch row instead of squeeze(), which would collapse a
    # one-class output to a scalar; .cpu() makes GPU outputs convertible.
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

    if class_names is None:
        categories = _WEIGHTS.meta["categories"]
        if len(categories) == len(probs):
            class_names = categories
        else:
            class_names = [str(i) for i in range(len(probs))]
    elif len(class_names) != len(probs):
        raise ValueError(
            f"class_names has {len(class_names)} entries but the model "
            f"produced {len(probs)} logits"
        )
    return {name: prob for name, prob in zip(class_names, probs)}


demo = gr.Interface(predict, gr.Image(), outputs=gr.Label(num_top_classes=3))

if __name__ == "__main__":
    demo.launch()