Spaces:
Build error
Build error
File size: 1,214 Bytes
d128a86 10adfee d128a86 ca5dc4d d128a86 ca5dc4d d128a86 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | from pathlib import Path
import gradio as gr
from PIL import Image
import numpy as np
from mnist_model.database import transform
from mnist_model.train import MNISTModel
import torch
current_dir = Path(__file__).absolute().parent
model_path = current_dir.joinpath("data", "models", "mnist_resnet18_4epochs.ckpt")

# Restore the training wrapper from its checkpoint and keep only the wrapped
# network (its ``.model`` attribute), which classify_image calls directly.
model = MNISTModel.load_from_checkpoint(str(model_path)).model
# Switch to evaluation mode so training-only behaviour (e.g. dropout, batch-norm
# statistic updates) is disabled during inference.
model.eval()
def classify_image(inp):
    """Classify a digit drawn on the Gradio sketchpad.

    Args:
        inp: The value produced by the ``gr.Sketchpad`` input; the drawing is
            read from its ``"composite"`` entry (a PIL image, since the
            component is created with ``type="pil"``). May be ``None``, or
            have no composite, when nothing has been drawn.

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to its softmax
        probability, as expected by ``gr.Label``; ``None`` when there is no
        image to classify (which clears the label instead of raising).
    """
    if inp is None or inp.get("composite") is None:
        return None

    pixels = np.array(inp["composite"])
    # Keep only the last channel -- presumably the alpha plane of an RGBA
    # composite, where strokes are opaque on a transparent background
    # (TODO confirm). A single-channel image already is that plane.
    if pixels.ndim == 3:
        pixels = pixels[:, :, -1]

    # transform() is the project's preprocessing pipeline (presumably resize /
    # to-tensor / normalize -- TODO confirm); add a batch dimension of 1.
    batch = transform()(Image.fromarray(pixels)).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]

    # One entry per model output class rather than a hard-coded 10.
    return {str(label): float(p) for label, p in enumerate(probs.tolist())}
# Gradio UI: a sketchpad whose value (a dict with a "composite" PIL image) is
# passed to classify_image, and a label showing the 5 most likely digits.
inputs = gr.Sketchpad(label="Draw a number", crop_size=(28, 28), type="pil")
outputs = gr.Label(num_top_classes=5)
title = "MNIST"
description = "A simple number recognition example model"
# The paragraph must be a closed string literal; an unterminated one is a
# SyntaxError that prevents the whole app from loading.
# TODO: the link that followed this sentence was lost -- restore its <a href>.
article = (
    "<p style='text-align: center'>Here is a simple implementation of number "
    "recognition using MNIST dataset and a Resnet18 backbone.</p>"
)

demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()
|