| from model import MiniVisionV2 | |
| import torch | |
| import torchvision | |
| import gradio as gr | |
| import webbrowser | |
# The checkpoint holds a whole pickled model object (it is used directly as a
# module below), which is why weights_only=False is required. Unpickling runs
# arbitrary code, so only load checkpoints from a trusted source.
# map_location="cpu" keeps the model on the CPU: the classifier builds its
# input tensors on the CPU, so a checkpoint saved from a GPU would otherwise
# fail at inference with a device mismatch.
minivisionv2 = torch.load(
    "Mini-Vision-V2.pth", map_location="cpu", weights_only=False
)
# Evaluation mode (affects layers such as dropout / batch-norm, if present).
minivisionv2.eval()

# Preprocessing for sketch-pad images: scale to exactly 28x28 and convert to a
# float tensor in [0, 1]. An explicit (28, 28) is used because a bare int only
# fixes the shorter side and keeps the aspect ratio, so a non-square canvas
# would otherwise produce a non-28x28 input.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((28, 28)),
    torchvision.transforms.ToTensor(),
])
def classifier(img):
    """Classify a hand-drawn digit from the Gradio sketch pad.

    Args:
        img: Value produced by the sketch-pad input; its ``"composite"``
            entry is the flattened drawing as a PIL image. May be ``None``
            (or carry a ``None`` composite) when nothing has been drawn.

    Returns:
        A dict mapping each class index as a string (``"0"``, ``"1"``, ...)
        to its softmax probability, for display in ``gr.Label``. Returns
        ``None`` for an empty canvas so the label output is cleared instead
        of the transform failing on a missing image.
    """
    if img is None or img["composite"] is None:
        return None

    # Invert intensities; presumably the model was trained on light digits on
    # a dark background, while the sketch pad draws dark on light - confirm.
    pixels = 1.0 - transform(img["composite"])
    batch = pixels.unsqueeze(0)  # add a leading batch dimension of size 1

    with torch.no_grad():
        logits = minivisionv2(batch)
    probabilities = torch.softmax(logits, dim=1)[0]

    # One entry per model output class (10 for a digit classifier).
    return {str(label): prob for label, prob in enumerate(probabilities.tolist())}
# Gradio UI: a 280x280 grayscale sketch pad feeds `classifier`, whose
# per-digit probabilities are shown in a label component.
demo = gr.Interface(
    fn=classifier,
    inputs=gr.Sketchpad(
        label="Sketch Pad",
        type="pil",
        image_mode="L",
        width=280,
        height=280,
    ),
    outputs=gr.Label(label="Classifying Results"),
    title="Mini-Vision-V2",
    description="Write number 0-9 in the sketch pad below",
)
if __name__ == '__main__':
    # Start the server without blocking so the browser is opened only once it
    # is actually listening, and at the URL Gradio really bound. The previous
    # hard-coded http://127.0.0.1:7860 was opened before launch (a race) and
    # was wrong whenever that port was already taken.
    demo.launch(share=True, prevent_thread_lock=True)
    webbrowser.open(demo.local_url)
    # Keep the main thread alive while the server runs, matching the original
    # blocking launch.
    demo.block_thread()