Spaces:
Runtime error
Runtime error
File size: 1,210 Bytes
c3e8ec5 24ada56 c3e8ec5 091c4a9 c3e8ec5 24ada56 091c4a9 24ada56 adc6366 24ada56 25c59d0 24ada56 9f93a31 24ada56 970f569 c3f29d4 24ada56 6a4757a d2a29c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import gradio as gr
import torch
from torchvision import datasets, models, transforms
from PIL import Image
LABELS = ['Fiat 500', 'VW Up!']

# Build the ResNet-18 architecture with a classification head sized to LABELS.
# No ImageNet weights are requested here: load_state_dict below (strict by
# default) overwrites every parameter and buffer, so downloading them only cost
# start-up time, and the ``pretrained=`` keyword is deprecated in newer
# torchvision releases. Calling resnet18() with no arguments yields an
# un-initialised-from-download model on both old and new torchvision.
model = models.resnet18()
num_ftrs = model.fc.in_features
# Output width follows LABELS so the head and the label list cannot drift apart.
model.fc = torch.nn.Linear(num_ftrs, len(LABELS))

# Fine-tuned weights stored as a state_dict; map_location='cpu' lets the
# checkpoint load on machines without a GPU.
# NOTE(review): torch.load unpickles the file; on torch >= 1.13 consider
# passing weights_only=True since only tensors are expected here.
state_dict = torch.load('up500Model.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # switch BatchNorm/Dropout layers to inference behaviour
title = "VW Up! or Fiat 500"
description = (
    "Demo for classification of automobiles. "
    "To use it, simply upload your image, "
    "or click one of the examples to load them."
)

# Per-channel normalisation statistics (the standard ImageNet values used by
# torchvision's pretrained models).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: scale the shorter side to 256 px, take the centre
# 224x224 crop, convert to a float tensor, then standardise each channel.
imgTransforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])
def predict(inp):
    """Classify an uploaded image as one of ``LABELS``.

    Args:
        inp: image array passed in by the Gradio ``image`` input
            (assumed to be a numpy array convertible to uint8 — confirm
            against the installed Gradio version).

    Returns:
        dict mapping each label in ``LABELS`` to its softmax probability,
        as consumed by the ``"label"`` output component.
    """
    # Let Pillow infer the mode from the array shape, then normalise to RGB,
    # so grayscale (H, W) and RGBA (H, W, 4) arrays are handled too; forcing
    # mode='RGB' mis-reads or rejects anything that is not already 3-channel
    # (and the mode argument is deprecated in recent Pillow).
    img = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = imgTransforms(img).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    with torch.no_grad():
        logits = model(batch)[0]
    # Explicit dim: softmax without dim is deprecated and emits a warning.
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(LABELS, probs)}
# Sample images shown under the input; each inner list is one example row.
examples = [
    ['fiat500.jpg'],
    ['VWUP.jpg'],
]

# Wire predict() to an image upload and a label display, then start the app.
interface = gr.Interface(
    fn=predict,
    inputs='image',
    outputs="label",
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
interface.launch()