File size: 1,214 Bytes
d128a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10adfee
d128a86
 
 
ca5dc4d
d128a86
ca5dc4d
d128a86
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from pathlib import Path

import gradio as gr 
from PIL import Image
import numpy as np
from mnist_model.database import transform
from mnist_model.train import MNISTModel
import torch


# Resolve the checkpoint path relative to this file so the app does not depend
# on the current working directory.
current_dir = Path(__file__).parent.absolute()

model_path = current_dir / "data" / "models" / "mnist_resnet18_4epochs.ckpt"

# Restore the weights onto the CPU explicitly: classify_image builds its input
# tensors on the CPU and never moves them to another device, so the network must
# live on the CPU too. This also lets a checkpoint saved during GPU training load
# on a CPU-only host. Only the wrapped ``.model`` network is kept for inference.
model = MNISTModel.load_from_checkpoint(
    str(model_path), map_location="cpu"
).model

# Put the network in inference mode (affects layers such as BatchNorm/Dropout).
model.eval()


def classify_image(inp):
    """Classify the digit drawn on the Gradio sketchpad.

    Args:
        inp: The sketchpad value. Only its ``"composite"`` entry is read; with
            ``type="pil"`` on the Sketchpad it holds the flattened drawing as a
            PIL image. May be ``None`` (or have a ``None`` composite) when
            nothing was drawn -- NOTE(review): confirm against the installed
            Gradio version.

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to its softmax
        probability, the format ``gr.Label`` consumes. An empty dict when there
        is no drawing to classify.
    """
    if inp is None or inp.get("composite") is None:
        return {}

    pixels = np.array(inp["composite"])

    # Keep only the last channel. This assumes a multi-channel (presumably
    # RGBA) composite whose alpha channel carries the strokes -- TODO confirm.
    # An already single-channel (2-D) image is used as-is.
    if pixels.ndim == 3:
        pixels = pixels[:, :, -1]

    # Add a leading batch dimension of size 1 for the model.
    batch = transform()(Image.fromarray(pixels)).unsqueeze(0)

    # Pure inference: skip building an autograd graph on every request.
    with torch.no_grad():
        logits = model(batch)

    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {str(digit): float(p) for digit, p in enumerate(probs)}
    

# Drawing surface; ``type="pil"`` makes classify_image receive PIL images.
inputs = gr.Sketchpad(label="Draw a number", crop_size=(28, 28), type="pil")
# Display the five most probable digits.
outputs = gr.Label(num_top_classes=5)

title = "MNIST"

description = "A simple number recognition example model"

# HTML rendered below the interface; keep the markup well-formed (the paragraph
# must be closed and must not end in a dangling, unterminated tag).
article = (
    "<p style='text-align: center'>Here is a simple implementation of number "
    "recognition using MNIST dataset and a Resnet18 backbone.</p>"
)

demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    article=article,
)
demo.launch()