File size: 1,813 Bytes
68ea1b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio
import torch
import torchvision
from model import MiniVisionV3
from PIL import Image, ImageOps


# Label set in model output order: the ten digits, the 26 upper-case letters,
# then the 11 lower-case letters a, b, d, e, f, g, h, n, q, r, t.
# Maps label string -> output index.
old_classes = {
    label: index
    for index, label in enumerate("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt")
}
# Inverse lookup used when decoding predictions: output index -> label string.
classes = dict(zip(old_classes.values(), old_classes.keys()))

# Preprocessing applied to the inverted drawing before it is fed to the model.
# Resize((28, 28)) pins BOTH sides to 28 px. A bare Resize(28) only scales the
# shorter edge, so a non-square canvas would yield a non-28x28 tensor; for a
# square canvas the two forms produce the same output.
# NOTE(review): 28x28 is assumed to be the network's expected input size — confirm.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])

def load_model(checkpoint_path="Mini-Vision-V3.pth"):
    """Build MiniVisionV3, load its trained weights, and switch it to eval mode.

    Args:
        checkpoint_path: Path to a file holding the model's ``state_dict``.
            Defaults to ``"Mini-Vision-V3.pth"``, resolved relative to the
            current working directory.

    Returns:
        The ``MiniVisionV3`` instance with weights loaded, in evaluation mode.
    """
    model = MiniVisionV3()
    # map_location="cpu": the model lives on the CPU here, and a checkpoint
    # saved from a GPU run would otherwise try to restore CUDA tensors and
    # fail on CPU-only hosts.
    # weights_only=True: a state dict only needs tensors, so refuse to
    # unpickle arbitrary Python objects from the file.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Put layers such as dropout / batch-norm (if the model has any) into
    # inference behaviour.
    model.eval()
    return model


# Loaded once at import time so every request reuses the same weights.
minivisionv3 = load_model()

def inference(img):
    """Classify the character drawn on the Sketchpad.

    Args:
        img: The Sketchpad value. With ``type="pil"`` it is expected to be a
            dict whose ``"composite"`` entry is the flattened drawing as a PIL
            image. It may be ``None`` (or have no composite) when nothing was
            drawn.

    Returns:
        A dict mapping each label string to its softmax probability, suitable
        for ``gradio.Label``. Returns ``None`` when there is no drawing to
        classify, which leaves the Label output empty instead of raising.
    """
    composite = img.get("composite") if img is not None else None
    if composite is None:
        return None

    # Force a single grayscale channel. ImageOps.invert only supports a few
    # modes, and the network is fed a one-channel image.
    grayscale = composite.convert("L")
    # Invert so the strokes become bright on a dark background.
    # NOTE(review): assumed to match the polarity of the training images — confirm.
    tensor = transform(ImageOps.invert(grayscale)).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = minivisionv3(tensor)
        probs = torch.softmax(logits, dim=1)[0].tolist()

    # `classes` maps output index -> label. Iterating it avoids a hard-coded
    # class count and keeps the result in output-index order.
    return {label: probs[index] for index, label in classes.items()}


# Gradio UI: a grayscale sketchpad (displayed at 560x560) feeds `inference`,
# and the returned per-label probabilities are shown in a Label component.
demo = gradio.Interface(
    fn=inference,
    inputs=gradio.Sketchpad(
        label="Draw Here",
        type="pil",
        image_mode="L",
        height=560,
        width=560,
    ),
    outputs=gradio.Label(label="Results"),
    title="Mini-Vision-V3",
    description="A lightweight CNN (0.4M params) trained on EMNIST Balanced for handwritten character recognition.",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()