Spaces:
Build error
Build error
File size: 1,224 Bytes
8b252d4 594c7c2 8b252d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import functools

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
def _default_weights():
    """Return the torchvision weights entry used when no model is supplied."""
    return torchvision.models.ResNet18_Weights.DEFAULT


@functools.lru_cache(maxsize=1)
def _default_model():
    """Build the pretrained ResNet-18 once and reuse it for every request.

    Constructing the model with pretrained weights loads (and on first use
    downloads) the checkpoint, so doing it on every prediction is very slow.
    """
    model = torchvision.models.resnet18(weights=_default_weights())
    model.eval()
    return model


def predict(img_path, model=None, *, class_names=None, transform=None):
    """Classify an image and return a ``{class_name: probability}`` mapping.

    Args:
        img_path: The image to classify. This is either a NumPy array (what
            ``gr.Image()`` delivers by default) or a path / file object
            accepted by ``PIL.Image.open``. ``None`` (an empty Gradio input)
            yields an empty dict.
        model: Optional classifier module. Defaults to a cached
            ImageNet-pretrained ResNet-18. It is put into eval mode before
            inference.
        class_names: Labels for the model's output scores, in order.
            Defaults to the categories listed in the ResNet-18 weights' metadata.
        transform: Callable mapping a PIL RGB image to a ``(C, H, W)`` tensor.
            Defaults to the preprocessing bundled with the ResNet-18 weights
            (``weights.transforms()``).

    Returns:
        dict mapping every class name to its softmax probability (a float).

    Raises:
        ValueError: if the number of model outputs differs from the number
            of class names.
    """
    if img_path is None:
        # Gradio calls the function with None when no image was provided.
        return {}

    weights = _default_weights()
    if model is None:
        model = _default_model()
    if class_names is None:
        class_names = weights.meta["categories"]
    if transform is None:
        # Use the preprocessing the weights were trained with instead of an
        # ad-hoc resize that distorts the image and degrades accuracy.
        transform = weights.transforms()

    if isinstance(img_path, np.ndarray):
        image = Image.fromarray(img_path).convert("RGB")
    else:
        # The context manager closes the underlying file handle promptly.
        with Image.open(img_path) as opened:
            image = opened.convert("RGB")

    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1

    model.eval()
    with torch.inference_mode():
        logits = model(batch)

    # Take the single batch row explicitly; .squeeze() would also collapse
    # a size-1 class dimension.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    if len(probs) != len(class_names):
        raise ValueError(
            f"model produced {len(probs)} scores but {len(class_names)} "
            "class names were given"
        )
    return dict(zip(class_names, probs))
# Wire `predict` into a Gradio UI: an uploaded image goes in (gr.Image's
# default delivery format is a NumPy array, which `predict` accepts) and the
# three most probable labels are shown.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()