LWWZH committed on
Commit
c8951e2
·
verified ·
1 Parent(s): 83becf1

Upload Mini-Vision-V3 demo

Browse files
Files changed (3) hide show
  1. Mini-Vision-V3.pth +3 -0
  2. demo.py +46 -0
  3. model.py +44 -0
Mini-Vision-V3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31ef54c30c43db9ab2954812d65cb34801f5583964ebb6a90195246f867a2036
3
+ size 1612075
demo.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio
2
+ import torch
3
+ import torchvision
4
+ from model import MiniVisionV3
5
+ from PIL import Image, ImageOps
6
+
7
+
8
+ old_classes = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'A': 10, 'B': 11, 'C': 12, 'D': 13, 'E': 14, 'F': 15, 'G': 16, 'H': 17, 'I': 18, 'J': 19, 'K': 20, 'L': 21, 'M': 22, 'N': 23, 'O': 24, 'P': 25, 'Q': 26, 'R': 27, 'S': 28, 'T': 29, 'U': 30, 'V': 31, 'W': 32, 'X': 33, 'Y': 34, 'Z': 35, 'a': 36, 'b': 37, 'd': 38, 'e': 39, 'f': 40, 'g': 41, 'h': 42, 'n': 43, 'q': 44, 'r': 45, 't': 46}
9
+ classes = {v: k for k, v in old_classes.items()}
10
+
11
+ transform = torchvision.transforms.Compose([
12
+ torchvision.transforms.Resize(28),
13
+ torchvision.transforms.ToTensor()])
14
+
15
+ def load_model():
16
+ minivisionv3 = MiniVisionV3()
17
+ state_dict = torch.load("Mini-Vision-V3.pth", weights_only=False)
18
+ minivisionv3.load_state_dict(state_dict)
19
+ minivisionv3.eval()
20
+ return minivisionv3
21
+
22
+ minivisionv3 = load_model()
23
+
24
+ def inference(img):
25
+ img_convert = ImageOps.invert(img["composite"])
26
+ input = transform(img_convert)
27
+ input = input.unsqueeze(0)
28
+
29
+ with torch.no_grad():
30
+ outputs = minivisionv3(input)
31
+ prob = torch.softmax(outputs, 1)
32
+
33
+ result = {}
34
+ for i in range(47):
35
+ result[str(classes[i])] = prob[0][i].item()
36
+ return result
37
+
38
+
39
+ demo = gradio.Interface(fn=inference,
40
+ inputs=gradio.Sketchpad(height=560, width=560, image_mode="L", label="Draw Here", type="pil"),
41
+ outputs=gradio.Label(label="Results"),
42
+ title="Mini-Vision-V3",
43
+ description="A lightweight CNN (0.4M params) trained on EMNIST Balanced for handwritten character recognition.")
44
+
45
+ if __name__ == '__main__':
46
+ demo.launch()
model.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
class MiniVisionV3(nn.Module):
    """Compact CNN (~0.4M parameters) for 47-class EMNIST Balanced classification."""

    def __init__(self):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # One downsampling stage: 3x3 conv -> batchnorm -> ReLU -> 2x2 max-pool.
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]

        # NOTE: the flat nn.Sequential (and thus the "model.N.*" state_dict
        # keys) must stay in exactly this order so the published checkpoint
        # still loads: three conv stages, then the classifier head.
        layers = conv_stage(1, 32) + conv_stage(32, 64) + conv_stage(64, 128)
        layers += [
            nn.Flatten(),
            nn.Linear(1152, 256),  # 28 -> 14 -> 7 -> 3 after three pools; 128 * 3 * 3 = 1152
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 47),    # EMNIST Balanced has 47 classes
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (batch, 47) for input of shape (batch, 1, 28, 28)."""
        return self.model(x)
34
+
35
if __name__ == '__main__':
    # Quick sanity check when run as a script: instantiate the model and
    # report its size (~0.4M parameters).
    minivisionv3 = MiniVisionV3()

    n_params = sum(p.numel() for p in minivisionv3.parameters())
    print(f"Total params: {n_params / 1000000: .2f}M")