# mnist_drawing/app.py
# Author: tchauffi
# Commit 10adfee: "ADD: Add top 5"
from pathlib import Path
import gradio as gr
from PIL import Image
import numpy as np
from mnist_model.database import transform
from mnist_model.train import MNISTModel
import torch
# Resolve the checkpoint relative to this file rather than the current working
# directory, so the app can be started from anywhere.
current_dir = Path(__file__).absolute().parent
model_path = current_dir.joinpath("data", "models", "mnist_resnet18_4epochs.ckpt")

# The checkpoint holds an MNISTModel wrapper; only its inner `.model` network is
# used for prediction below.
model = MNISTModel.load_from_checkpoint(str(model_path)).model
# Switch to evaluation mode for inference (affects layers such as batch-norm
# and dropout).
model.eval()
def classify_image(inp):
    """Predict per-digit probabilities for a drawing from the Sketchpad.

    Args:
        inp: The Sketchpad value. Only ``inp["composite"]`` is read; with the
            component created as ``type="pil"`` this is expected to be a PIL
            image. ``inp`` or its composite may presumably be ``None`` when
            nothing was drawn/submitted.

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to its softmax
        probability, the format ``gr.Label`` consumes. ``None`` (which clears
        the label) when there is no drawing to classify.
    """
    img = inp.get("composite") if inp is not None else None
    if img is None:
        return None

    img = np.array(img)
    # Take only the last channel. NOTE(review): for an RGBA composite this is
    # the alpha channel, i.e. the strokes on a zero background — presumably
    # chosen to resemble MNIST's light-on-dark digits; confirm. A 2-D
    # (single-channel) image is used as-is instead of crashing on the index.
    if img.ndim == 3:
        img = img[:, :, -1]

    img = Image.fromarray(img)
    img = transform()(img)
    img = img.unsqueeze(0)  # add a batch dimension of size 1

    # Inference only: skip autograd bookkeeping for every request.
    with torch.no_grad():
        logits = model(img)
    probs = torch.softmax(logits, dim=1)[0]

    # Follow the model's output width rather than hard-coding 10 classes.
    return {str(i): float(p) for i, p in enumerate(probs)}
# Build and launch the Gradio UI around `classify_image`.
inputs = gr.Sketchpad(label="Draw a number", crop_size=(28, 28), type="pil")
# Display the five most probable digits with their scores.
outputs = gr.Label(num_top_classes=5)
title = "MNIST"
description = "A simple number recognition example model"
# NOTE(review): the original article ended in a truncated "<a href=" tag with
# no target and an unclosed <p>; the markup is closed here. TODO: restore the
# intended link if its URL is known.
article = (
    "<p style='text-align: center'>Here is a simple implementation of number "
    "recognition using MNIST dataset and a Resnet18 backbone.</p>"
)

# Bound to `demo`, the name Gradio's reload mode looks for by default.
demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()