"""Gradio demo that classifies a car photo as a Fiat 500 or a VW Up!.

A ResNet-18 whose final layer has been replaced by a two-way classifier is
restored from ``up500Model.pt`` and served through a Gradio interface.
"""

import gradio as gr
import torch
from torchvision import datasets, models, transforms
from PIL import Image

# Class names, in the order of the classifier's output logits.
LABELS = ['Fiat 500', 'VW Up!']

# Build the bare architecture only: every parameter and buffer is overwritten
# by the (strict) load_state_dict call below, so downloading the ImageNet
# weights first would be wasted network traffic at start-up.
model = models.resnet18()
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, len(LABELS))
state_dict = torch.load('up500Model.pt', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # inference mode: freezes batch-norm statistics

title = "VW Up! or Fiat 500"
description = "Demo for classification of automobiles. To use it, simply upload your image, or click one of the examples to load them."

# Resize/crop to the 224x224 input the network is fed, then normalise each
# channel with the mean/std values conventionally used for ImageNet-trained
# torchvision models.
imgTransforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(inp):
    """Return the class probabilities for one uploaded image.

    Args:
        inp: Image as a NumPy array, as delivered by Gradio's ``'image'``
            input component.

    Returns:
        A dict mapping each entry of ``LABELS`` to its softmax probability
        (a Python ``float``), the format Gradio's ``"label"`` output expects.

    Raises:
        ValueError: If no image was supplied (``inp`` is ``None``).
    """
    if inp is None:
        raise ValueError("No image provided.")

    # Let Pillow infer the mode from the array shape and then convert, so
    # grayscale or RGBA arrays are handled instead of being misread as RGB.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = imgTransforms(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        logits = model(batch)[0]  # 1-D tensor: one logit per class
        # dim is given explicitly; the implicit choice is deprecated.
        prediction = torch.nn.functional.softmax(logits, dim=0)

    return {label: float(prob) for label, prob in zip(LABELS, prediction)}


examples = [['fiat500.jpg'], ['VWUP.jpg']]

interface = gr.Interface(
    predict,
    inputs='image',
    outputs="label",
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
interface.launch()