Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision import models, transforms, datasets | |
| from PIL import Image | |
| import gradio as gr | |
# Class names, in the order of the classifier's output logits.
LABELS = ['fiat 500', 'VW Up!']

# ResNet-18 backbone with its final fully-connected layer replaced by a
# len(LABELS)-way classifier (2 outputs, matching the saved checkpoint).
# No pretrained ImageNet weights are requested: the strict load_state_dict
# below overwrites every parameter and buffer anyway, so downloading them
# would only add start-up time and a network dependency.
model_ft = models.resnet18()
num_ftrs = model_ft.fc.in_features
# `nn` is never imported in this file; reach the layer through `torch`.
model_ft.fc = torch.nn.Linear(num_ftrs, len(LABELS))
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host.
state_dict = torch.load('up500Model.pt', map_location='cpu')
model_ft.load_state_dict(state_dict)
# Inference mode: batch-norm layers use their stored running statistics.
model_ft.eval()
# Per-channel normalization statistics. These are the commonly used
# ImageNet mean/std values for RGB inputs.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image:
# 1. resize the shorter side to 256 px;
# 2. take the central 224x224 crop;
# 3. convert to a float tensor;
# 4. normalize each channel.
imgTransforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
def predict(inp):
    """Classify an image and return a confidence per label.

    Args:
        inp: image array as delivered by the Gradio ``'image'`` input
            (presumably an HxWxC or HxW numpy array; confirm against the
            Gradio version in use).

    Returns:
        dict mapping each name in ``LABELS`` to its softmax probability,
        the shape consumed by the ``'label'`` output of the interface.
    """
    # 'uint8' (the original 'unit8' was a typo that raised TypeError) is the
    # dtype PIL expects. convert('RGB') also copes with grayscale or RGBA
    # arrays instead of forcing a mode that misreads them.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = imgTransforms(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model_ft(batch)[0]
        # Explicit dim: the implicit-dim form of softmax is deprecated.
        probabilities = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(LABELS, probabilities)}
# Web UI: a single image upload fed to `predict`, whose per-label
# probabilities are rendered by Gradio's label component.
interface = gr.Interface(
    fn=predict,
    inputs='image',
    outputs='label',
    title='Car classification',
)
interface.launch()