from pathlib import Path import gradio as gr from PIL import Image import numpy as np from mnist_model.database import transform from mnist_model.train import MNISTModel import torch current_dir = Path(__file__).parent.absolute() model_path = current_dir / "data" / "models" / "mnist_resnet18_4epochs.ckpt" model = MNISTModel.load_from_checkpoint(str(model_path)).model model.eval() def classify_image(inp): img = inp["composite"] img = np.array(img) # take onliy the last channel img = img[:, :, -1] img = Image.fromarray(img) img = transform()(img) img = img.unsqueeze(0) pred = model(img) pred = torch.softmax(pred, dim=1) return {str(i): float(pred[0][i]) for i in range(10)} inputs = gr.Sketchpad(label="Draw a number", crop_size=(28, 28), type="pil") outputs = gr.Label(num_top_classes=5) title = "MNIST" description = "A simple number recognition example model" article = "
Here is a simple implementation of number recognition using MNIST dataset and a Resnet18 backbone.