import gradio
import torch
import torchvision
from model import MiniVisionV3
from PIL import Image, ImageOps
# EMNIST Balanced label set in index order: the 10 digits, the 26 uppercase
# letters, then the lowercase letters whose handwritten shapes are distinct
# enough from their uppercase forms to get their own class.
_LABEL_CHARS = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabdefghnqrt"

# Character label -> class index (kept for parity with the training setup).
old_classes = {char: index for index, char in enumerate(_LABEL_CHARS)}
# Class index -> character label, used to name the model's output slots.
classes = {index: char for index, char in enumerate(_LABEL_CHARS)}
# Preprocessing pipeline: shrink the (square) sketch so its edge is 28 px —
# the EMNIST input size — then convert the PIL image to a float tensor
# scaled to [0, 1].
_preprocess_steps = [
    torchvision.transforms.Resize(28),
    torchvision.transforms.ToTensor(),
]
transform = torchvision.transforms.Compose(_preprocess_steps)
def load_model():
    """Build a MiniVisionV3 and load its trained weights from disk.

    Returns:
        The model with "Mini-Vision-V3.pth" loaded, switched to eval mode
        and ready for CPU inference.
    """
    model = MiniVisionV3()
    # weights_only=True restricts torch.load to tensors/containers, closing
    # the arbitrary-code-execution hole that full pickle loading opens; a
    # plain state_dict needs nothing more. map_location="cpu" keeps a
    # GPU-saved checkpoint loadable on CPU-only hosts.
    state_dict = torch.load(
        "Mini-Vision-V3.pth", map_location="cpu", weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # freeze dropout/batch-norm behavior for inference
    return model


minivisionv3 = load_model()
def inference(img):
    """Classify a Sketchpad drawing and return per-class probabilities.

    Parameters:
        img: Gradio Sketchpad payload; img["composite"] holds the drawn PIL
            image. It is inverted because EMNIST glyphs are light-on-dark,
            the opposite of the sketch canvas — presumably dark strokes on
            a light background (TODO confirm against the Sketchpad config).

    Returns:
        dict mapping each character label to its softmax probability, the
        shape gradio.Label expects.
    """
    inverted = ImageOps.invert(img["composite"])
    # unsqueeze(0) adds the batch dimension the model expects.
    batch = transform(inverted).unsqueeze(0)
    with torch.no_grad():
        logits = minivisionv3(batch)
        probabilities = torch.softmax(logits, 1)
    # Iterate the label table instead of hard-coding the class count (47);
    # labels are already strings, so no str() conversion is needed. This
    # also avoids shadowing the builtin `input`, as the original did.
    return {
        label: probabilities[0][index].item()
        for index, label in classes.items()
    }
# Wire the classifier into a simple web UI: a 560x560 grayscale drawing
# canvas feeding `inference`, with probabilities rendered as a label widget.
sketch_input = gradio.Sketchpad(
    height=560, width=560, image_mode="L", label="Draw Here", type="pil"
)
result_output = gradio.Label(label="Results")
demo = gradio.Interface(
    fn=inference,
    inputs=sketch_input,
    outputs=result_output,
    title="Mini-Vision-V3",
    description="A lightweight CNN (0.4M params) trained on EMNIST Balanced for handwritten character recognition.",
)

if __name__ == "__main__":
    demo.launch()