ResNet18-model / app.py
Sandhya
add requirements
594c7c2
raw
history blame
1.22 kB
import functools

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
@functools.lru_cache(maxsize=1)
def _default_model():
    """Return the ImageNet-pretrained ResNet18, built once and then reused.

    Building the model (and fetching its weights) is the most expensive part
    of a prediction, so it is cached instead of being redone per request.
    """
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    model = torchvision.models.resnet18(weights=weights)
    model.eval()
    return model


def _to_rgb_image(img):
    """Convert a NumPy array, PIL image, or path/file object into an RGB PIL image.

    Raises:
        ValueError: if ``img`` is None (e.g. the form was submitted without an image).
    """
    if img is None:
        raise ValueError("No image was provided.")
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if isinstance(img, np.ndarray):
        return Image.fromarray(img).convert("RGB")
    # Context manager closes the file handle; convert() returns a fully
    # loaded copy that remains valid after the handle is closed.
    with Image.open(img) as opened:
        return opened.convert("RGB")


def predict(img_path, model=None, class_names=None, transform=None):
    """Classify an image and return a ``{class_name: probability}`` mapping.

    Args:
        img_path: Image to classify: a NumPy array (what ``gr.Image()``
            delivers by default), a PIL image, or a path / file object
            accepted by ``PIL.Image.open``.
        model: Classifier to run. Defaults to a cached ImageNet-pretrained
            ResNet18.
        class_names: Labels indexed by output logit. Defaults to the
            categories bundled with ``ResNet18_Weights.DEFAULT``.
        transform: Callable mapping an RGB PIL image to an image tensor.
            Defaults to ``ResNet18_Weights.DEFAULT.transforms()``, i.e. the
            preprocessing those weights were trained with.

    Returns:
        dict mapping every class name to its softmax probability (float).

    Raises:
        ValueError: if no image is given, or if the number of model outputs
            does not match ``len(class_names)``.
    """
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    if model is None:
        model = _default_model()
    if class_names is None:
        # Resolved independently of `model`, so a caller-supplied model no
        # longer leaves class_names unbound.
        class_names = weights.meta["categories"]
    if transform is None:
        # Use the preprocessing bundled with the weights so inference
        # matches how the network was trained.
        transform = weights.transforms()

    image = _to_rgb_image(img_path)
    batch = transform(image).unsqueeze(0)  # prepend a batch dimension of 1

    # Put the input on the same device as the model (the default model is on CPU).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch = batch.to(first_param.device)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    # .cpu() so this also works for models living on an accelerator.
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu().tolist()

    if len(probs) != len(class_names):
        raise ValueError(
            f"Model produced {len(probs)} scores but {len(class_names)} "
            "class names were given."
        )
    return dict(zip(class_names, probs))
# Gradio UI: one image input (delivered to `predict` as a NumPy array by
# default) and a label output that shows the three most probable classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    demo.launch()