# Source of a Hugging Face Space app (commit 621127e, 1,116 bytes).
# At capture time the Space's status was "Runtime error".
import torch
from torchvision import models, transforms, datasets
from PIL import Image
import gradio as gr
# Class names, in the order of the classifier's output units.
LABELS = ['fiat 500', 'VW Up!']

# ResNet-18 backbone with no pretrained weights requested: every parameter
# and buffer is replaced by the (strict) load_state_dict below, so fetching
# ImageNet weights would only cost start-up time and a network dependency.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
# Replace the final fully-connected layer with a len(LABELS)-way head so the
# output size always agrees with the label list used by predict().
# Fully-qualified torch.nn is used because `nn` is not imported in this file.
model_ft.fc = torch.nn.Linear(num_ftrs, len(LABELS))
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host.
state_dict = torch.load('up500Model.pt', map_location='cpu')
model_ft.load_state_dict(state_dict)
# Inference mode (e.g. BatchNorm layers use their running statistics).
model_ft.eval()
# Standard ImageNet per-channel mean / std. Presumably these match the
# preprocessing used when up500Model.pt was trained -- confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Evaluation-time preprocessing for a PIL image: resize (shorter side to 256),
# take the central 224x224 crop, convert to a float tensor, then normalise
# each channel with the statistics above.
imgTransforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)
def predict(inp):
    """Classify one image and return per-label probabilities.

    Args:
        inp: the image as a NumPy array, as supplied by Gradio's ``'image'``
            input. Presumably H x W x C; 2-D grayscale and 4-channel RGBA
            arrays are converted to 3-channel RGB before preprocessing.

    Returns:
        dict mapping each entry of ``LABELS`` to its softmax probability
        (a Python float), which is the format Gradio's ``'label'`` output
        displays.
    """
    # 'uint8' is the NumPy dtype name; PIL expects 8-bit pixel data here.
    img = Image.fromarray(inp.astype('uint8'))
    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale/RGBA uploads also yield the 3 channels the model expects.
    img = img.convert('RGB')
    # Add a leading batch dimension of size 1 for the model.
    batch = imgTransforms(img).unsqueeze(0)
    with torch.no_grad():
        # [0] selects the single image's logits from the batch output.
        logits = model_ft(batch)[0]
        # Explicit dim: calling softmax without `dim` is deprecated.
        probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(LABELS, probs)}
# Wire `predict` into a Gradio UI (image upload in, label/probabilities out)
# and start the web server.
interface = gr.Interface(
    fn=predict,
    inputs='image',
    outputs='label',
    title='Car classification',
)
interface.launch()