Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import torchvision.transforms as transforms | |
| import torch.nn.functional as F | |
| import gradio as gr | |
| from model import Net | |
| # loads demo model | |
| if torch.cuda.is_available(): | |
| dev = "cuda:0" | |
| else: | |
| dev = "cpu" | |
| device = torch.device(dev) | |
| model = torch.load(f"./demo_model.pt", map_location=device) | |
| model.eval() | |
| # inference function | |
| def inference(img): | |
| transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((28, 28))]) | |
| img = transform(img).unsqueeze(0) # transforms ndarray and adds batch dimension | |
| with torch.no_grad(): | |
| output_probabilities = F.softmax(model(img), dim=1)[0] # probability prediction for each label | |
| return {labels[i]: float(output_probabilities[i]) for i in range(len(labels))} | |
| # Creates and launches gradio interface | |
| labels = range(10) # 1-9 labels | |
| outputs = gr.outputs.Label(num_top_classes=5) | |
| gr.Interface(fn=inference, inputs='sketchpad', outputs=outputs, title="MNIST Interface", | |
| description="Draw a number from 0-9 in the box and click submit to see the model's predictions.").launch() | |