# Mini-Vision-V3 / demo.py
# Author: LWWZH
# Commit: 68ea1b0 ("Upload Mini-Vision-V3 Model")
import gradio
import torch
import torchvision
from model import MiniVisionV3
from PIL import Image, ImageOps
# Class labels in model-output order: index i of the network's logits is the
# character at position i of this string. 47 classes in total; several
# lowercase letters are absent (presumably merged with their capitals, as in
# EMNIST Balanced — confirm against the training code).
_EMNIST_BALANCED_LABELS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt"

# label -> class index (e.g. 'A' -> 10)
old_classes = {label: index for index, label in enumerate(_EMNIST_BALANCED_LABELS)}
# class index -> label (e.g. 10 -> 'A'); used to name the softmax outputs
classes = dict(enumerate(_EMNIST_BALANCED_LABELS))
# Preprocessing for each (already inverted) sketch before it reaches the model:
# scale to exactly 28x28 pixels (the size of EMNIST images), then convert the
# PIL image to a float tensor of shape (C, H, W) with values in [0, 1].
transform = torchvision.transforms.Compose([
    # A (height, width) tuple forces a 28x28 output even when the canvas is
    # not square; a bare int would only scale the shorter edge to 28 and leave
    # the other edge proportionally larger.
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])
def load_model(path="Mini-Vision-V3.pth"):
    """Build a MiniVisionV3 network and load trained weights into it.

    Args:
        path: Checkpoint file containing the model's state_dict. A relative
            path is resolved against the current working directory.

    Returns:
        The MiniVisionV3 instance with weights loaded, switched to eval mode.
    """
    model = MiniVisionV3()
    # map_location="cpu": the model here is never moved to a GPU, so load the
    # tensors onto the CPU; this also lets a checkpoint saved on a GPU machine
    # load on a CPU-only host.
    # weights_only=True: load_state_dict only needs a mapping of tensors, so
    # the restricted unpickler suffices and no arbitrary pickled code can run.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference behaviour for layers such as dropout/batch-norm
    return model


# Loaded once at import time and shared by every inference() call.
minivisionv3 = load_model()
def inference(img):
    """Classify the character drawn on the Gradio Sketchpad.

    Args:
        img: The Sketchpad value: a dict whose "composite" entry is the
            flattened drawing as a PIL image (the component uses type="pil").
            May be None when there is nothing to classify.

    Returns:
        A dict mapping each class label to its softmax probability (one entry
        per label in ``classes``, in class-index order), or None when no
        drawing was supplied.
    """
    composite = None if img is None else img["composite"]
    if composite is None:
        # Nothing drawn / canvas cleared: let the Label output show nothing
        # instead of raising.
        return None

    # convert("L") guarantees a mode ImageOps.invert supports; inverting
    # swaps light and dark, presumably to match the polarity of the training
    # images (confirm against the training pipeline).
    img_convert = ImageOps.invert(composite.convert("L"))
    tensor = transform(img_convert).unsqueeze(0)  # add a batch dimension of 1

    with torch.no_grad():
        outputs = minivisionv3(tensor)
    probs = torch.softmax(outputs, dim=1)[0].tolist()

    return {label: probs[index] for index, label in classes.items()}
# Gradio UI: a 560x560 greyscale ("L" mode) sketchpad whose drawing is passed
# to `inference`; the returned class probabilities are shown in a Label.
_sketchpad = gradio.Sketchpad(
    height=560,
    width=560,
    image_mode="L",
    label="Draw Here",
    type="pil",
)
_results = gradio.Label(label="Results")

demo = gradio.Interface(
    fn=inference,
    inputs=_sketchpad,
    outputs=_results,
    title="Mini-Vision-V3",
    description=(
        "A lightweight CNN (0.4M params) trained on EMNIST Balanced "
        "for handwritten character recognition."
    ),
)

if __name__ == "__main__":
    # Launch the app only when run as a script, not when imported.
    demo.launch()