# ResNet18-model / app.py
# Author: Sandhya — commit 24fa9b5 ("First commit")
import functools

import numpy as np
import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import transforms
@functools.lru_cache(maxsize=1)
def _default_model():
    """Build torchvision's ImageNet-pretrained ResNet18 once and reuse it."""
    model = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.DEFAULT
    )
    model.eval()
    return model


def predict(img_path, model=None, class_names=None, transform=None):
    """Classify an image and return a ``{class_name: probability}`` mapping.

    Args:
        img_path: Either something ``PIL.Image.open`` accepts (path or file
            object) or a NumPy image array, which is what ``gr.Image`` passes.
        model: Classifier to run. Defaults to torchvision's ImageNet-pretrained
            ResNet18, built on first use and cached for later calls.
        class_names: Label for each model output, in output order. Defaults to
            the ImageNet category names in ``ResNet18_Weights.DEFAULT.meta``.
        transform: Callable mapping a PIL RGB image to a ``(C, H, W)`` tensor.
            Defaults to the preprocessing preset bundled with
            ``ResNet18_Weights.DEFAULT`` so inputs match what the pretrained
            weights were trained on.

    Returns:
        dict mapping each class name to its softmax probability as a float.

    Raises:
        ValueError: if no image is given, or if the number of class names does
            not match the number of model outputs.
    """
    if img_path is None:
        # Gradio calls the function with None when the user submits no image.
        raise ValueError("No image provided")

    weights = torchvision.models.ResNet18_Weights.DEFAULT
    if model is None:
        model = _default_model()
    if class_names is None:
        class_names = weights.meta["categories"]
    if transform is None:
        transform = weights.transforms()

    if isinstance(img_path, np.ndarray):
        image = Image.fromarray(img_path).convert("RGB")
    else:
        # Context manager closes the underlying file once the pixels are copied.
        with Image.open(img_path) as src:
            image = src.convert("RGB")

    batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, C, H, W)
    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    # Index row 0 rather than squeeze() so a single-class output stays 1-D.
    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

    if len(class_names) != len(probs):
        raise ValueError(
            f"model produced {len(probs)} outputs but "
            f"{len(class_names)} class names were given"
        )
    return {name: float(p) for name, p in zip(class_names, probs)}
import gradio as gr

# gr.Image(type="numpy") hands predict() a NumPy array; gr.Label renders the
# returned {class: probability} dict, showing the three highest-scoring classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),
)

if __name__ == "__main__":
    demo.launch()