# ResNet18-model / app.py
# Sandhya
# Commit 594c7c2: "add requirements"
import functools

import gradio as gr
import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
# Per-channel ImageNet statistics that torchvision's pretrained classifiers
# were trained with; inputs must be normalized the same way.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _default_transform():
    """Return the preprocessing pipeline matching ResNet18's ImageNet weights."""
    # The pretrained weights are evaluated on 224x224 centre crops of images
    # resized to 256 px on the short side. Resizing to 64x64 instead feeds the
    # network far smaller inputs than it was trained on and hurts accuracy.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


@functools.lru_cache(maxsize=1)
def _load_default_model():
    """Build the pretrained ResNet18 once and reuse it for every call."""
    # Constructing the model and loading its weights is expensive, so it is
    # cached instead of being rebuilt on each prediction request.
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    model = torchvision.models.resnet18(weights=weights)
    model.eval()
    return model


def _default_class_names():
    """Return the ImageNet category names bundled with the ResNet18 weights."""
    return list(torchvision.models.ResNet18_Weights.DEFAULT.meta["categories"])


def _to_rgb_image(img):
    """Convert a numpy array, PIL image, or file path into an RGB PIL image."""
    if isinstance(img, np.ndarray):
        return Image.fromarray(img).convert("RGB")
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    # Close the underlying file handle; convert() returns a new, fully
    # loaded image that stays valid after the file is closed.
    with Image.open(img) as opened:
        return opened.convert("RGB")


def predict(img_path, model=None, class_names=None, transform=None):
    """Classify an image and return a ``{class_name: probability}`` mapping.

    Args:
        img_path: The image to classify. Accepts an ``np.ndarray``, which is
            what Gradio's ``gr.Image()`` passes by default. Also accepts a
            ``PIL.Image.Image`` or anything ``PIL.Image.open`` accepts, such
            as a filesystem path. ``None`` yields an empty result.
        model: Optional classifier module. When ``None``, a cached pretrained
            torchvision ResNet18 is used. A supplied model is switched to
            eval mode as a side effect.
        class_names: Labels for the model's output scores, in order. The
            default is the ImageNet category list shipped with the ResNet18
            weights.
        transform: Optional callable mapping an RGB PIL image to a CHW
            tensor. The default is ResNet18's ImageNet preprocessing (resize
            256, centre-crop 224, ImageNet normalization).

    Returns:
        dict mapping each class name to its softmax probability (a float).
        Returns an empty dict when ``img_path`` is ``None``.

    Raises:
        ValueError: If the number of model outputs differs from the number
            of class names.
    """
    # Gradio passes None when the input image is cleared and submitted.
    if img_path is None:
        return {}
    if model is None:
        model = _load_default_model()
    if class_names is None:
        class_names = _default_class_names()
    if transform is None:
        transform = _default_transform()

    image = _to_rgb_image(img_path)

    # Run on whatever device holds the model's weights (CPU for
    # parameter-free models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    batch = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    # Take the single batch row instead of squeeze(). squeeze() would also
    # collapse a one-class output to a 0-d tensor. Move the result to the CPU
    # before converting it to Python floats.
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

    if len(probs) != len(class_names):
        raise ValueError(
            f"model produced {len(probs)} scores but {len(class_names)} "
            "class names were given"
        )
    return {name: float(p) for name, p in zip(class_names, probs)}
# Web UI: upload an image, then show the three highest-scoring labels that
# `predict` returns. Gradio's gr.Image() hands `predict` a numpy array by
# default.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # Start the local Gradio server only when run as a script, not on import.
    demo.launch()