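# NOTE: the lines below are repository-page metadata, kept as comments so
# that the module parses as Python.
# ResNet18-model / app.py
# Sandhya
# First commit
# 8b252d4
# raw
# history blame
# 1.37 kB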
import torchvision
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
import onnxruntime
import numpy as np
import torch.nn.functional as F
from safetensors.torch import save_file,load_file,safe_open
import numpy as np
# Lazily-built default model, shared across calls so that each prediction
# does not re-create the network and re-load its pretrained weights.
_DEFAULT_MODEL = None


def _get_default_model():
    """Return the pretrained ResNet-18, building it on first use."""
    global _DEFAULT_MODEL
    if _DEFAULT_MODEL is None:
        weights = torchvision.models.ResNet18_Weights.DEFAULT
        _DEFAULT_MODEL = torchvision.models.resnet18(weights=weights)
    return _DEFAULT_MODEL


def predict(img_path, model=None):
    """Classify an image and return a probability for every class name.

    Args:
        img_path: Either a ``numpy.ndarray`` holding the image pixels or
            anything ``PIL.Image.open`` accepts (e.g. a file path).
        model: Optional classifier to use instead of the default pretrained
            ResNet-18. Its outputs are labelled with the category names of
            ``ResNet18_Weights.DEFAULT``, so it must emit one logit per
            category.

    Returns:
        A dict mapping each class name to its softmax probability (a
        Python ``float``).

    Raises:
        ValueError: If the model does not produce exactly one score per
            class name.
    """
    # Resolve the weights enum unconditionally: the class names are needed
    # even when the caller supplies their own model.
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    class_names = weights.meta["categories"]
    if model is None:
        model = _get_default_model()

    # NOTE(review): the ImageNet weights are normally evaluated at 224x224;
    # the 64x64 resize is kept from the original -- confirm it is intended.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    if isinstance(img_path, np.ndarray):
        image = Image.fromarray(img_path).convert("RGB")
    else:
        image = Image.open(img_path).convert("RGB")

    # Add the batch dimension the model expects.
    batch = transform(image).unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1).squeeze(0).tolist()

    if len(probabilities) != len(class_names):
        raise ValueError(
            f"model produced {len(probabilities)} scores but there are "
            f"{len(class_names)} class names"
        )
    return dict(zip(class_names, probabilities))
import numpy as np
import gradio as gr
# Gradio UI: an image input wired to `predict`, with a label output that
# displays the three highest-scoring classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=3),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()