LWWZH committed
Commit 68ea1b0 · verified · 1 Parent(s): 8df070e

Upload Mini-Vision-V3 Model

Files changed (5)
  1. Mini-Vision-V3.pth +3 -0
  2. README.md +121 -0
  3. demo.py +46 -0
  4. model.py +44 -0
  5. train.py +95 -0
Mini-Vision-V3.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31ef54c30c43db9ab2954812d65cb34801f5583964ebb6a90195246f867a2036
size 1612075
README.md ADDED
@@ -0,0 +1,121 @@
---
license: mit
library_name: pytorch
tags:
- image-classification
- emnist
- cnn
- computer-vision
- pytorch
- mini-vision
- mini-vision-series
metrics:
- accuracy
pipeline_tag: image-classification
datasets:
- emnist
---

# Mini-Vision-V3: EMNIST Balanced Handwritten Character Classifier

![Model Size](https://img.shields.io/badge/Params-0.40M-blue) ![Accuracy](https://img.shields.io/badge/Accuracy-90.06%25-green)

Welcome to **Mini-Vision-V3**, the third model in the Mini-Vision series. Following the MNIST digit-recognition task of V2, this model expands to **47 classes** of handwritten characters (digits, uppercase letters, and lowercase letters) using the EMNIST Balanced dataset. It features a deeper yet highly efficient 3-layer CNN architecture, achieving over 90% accuracy with fewer than half a million parameters.

## Model Description

Mini-Vision-V3 is a custom 3-layer CNN architecture tailored to 28x28 grayscale images. While keeping a lightweight footprint of only **0.40M parameters** (half the size of V2), it handles the significantly increased complexity of 47 character classes. The project demonstrates how added depth and Batch Normalization can improve performance on a harder classification task without growing the model.

- **Dataset**: [EMNIST Balanced](https://www.nist.gov/itl/products-and-services/emnist-dataset) (28x28 grayscale images, 47 classes)
- **Framework**: PyTorch
- **Total Parameters**: 0.40M

## Model Architecture

The network uses a deeper structure than V2, with three convolutional blocks for stronger feature extraction on the more complex 47-class task.

| Layer | Input Channels | Output Channels | Kernel Size | Stride | Padding | Activation | Other |
| :--- | :---: | :---: | :---: | :---: | :---: | :--- | :--- |
| **Conv Block 1** | 1 | 32 | 3 | 1 | 1 | ReLU | BatchNorm, MaxPool(2) |
| **Conv Block 2** | 32 | 64 | 3 | 1 | 1 | ReLU | BatchNorm, MaxPool(2) |
| **Conv Block 3** | 64 | 128 | 3 | 1 | 1 | ReLU | BatchNorm, MaxPool(2) |
| **Flatten** | - | - | - | - | - | - | Output: 1152 |
| **Linear 1** | 1152 | 256 | - | - | - | ReLU | Dropout(0.3) |
| **Linear 2** | 256 | 47 | - | - | - | - | - |

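The Flatten output of 1152 in the table follows from the pooling arithmetic: each convolution (3x3 kernel, padding=1) preserves the spatial size, each of the three MaxPool(2) stages halves it with floor division (28 → 14 → 7 → 3), and the last block emits 128 channels. A quick sanity check in plain Python:

```python
# Conv2d(3x3, padding=1) preserves spatial size; MaxPool2d(2) halves it (floor).
size = 28
for _ in range(3):     # three conv blocks
    size //= 2         # 28 -> 14 -> 7 -> 3
flatten_size = 128 * size * size   # 128 channels out of the last conv block
print(size, flatten_size)  # 3 1152
```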
## Training Strategy

The training setup was adjusted for the larger dataset and increased class complexity, using a higher initial learning rate together with a StepLR scheduler for stable convergence.

- **Optimizer**: SGD (momentum=0.8)
- **Initial Learning Rate**: 0.05
- **Scheduler**: StepLR (step size=5, gamma=0.5)
- **Loss Function**: CrossEntropyLoss
- **Batch Size**: 256
- **Epochs**: 50 (best model at epoch 40)
- **Data Preprocessing**:
  - EMNIST-specific alignment: rotate -90 degrees, then flip horizontally (to match standard image orientation).
  - Random crop (28x28 with padding=2)
  - Random rotation (10 degrees)

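An aside on the alignment step: under torchvision's convention that positive angles rotate counter-clockwise, rotating by -90 degrees and then flipping horizontally amounts to a matrix transpose. A small sketch on a 3x3 grid (pure Python; the grid values are arbitrary):

```python
def rotate_minus_90(grid):
    # Rotate 90 degrees clockwise: result[i][j] = grid[n-1-j][i]
    n = len(grid)
    return [[grid[n - 1 - j][i] for j in range(n)] for i in range(n)]

def hflip(grid):
    # Mirror each row left-to-right
    return [row[::-1] for row in grid]

grid = [[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]

fixed = hflip(rotate_minus_90(grid))
transposed = [list(row) for row in zip(*grid)]
print(fixed == transposed)  # True
```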
## Performance

The model achieved solid results on the EMNIST Balanced test set (18,800 samples); the reported checkpoint is the best-performing epoch (epoch 40):

| Metric | Value |
| :--- | :---: |
| **Test Accuracy** | **90.06%** |
| Test Loss | 0.28 |
| Train Loss | 0.28 |
| Parameters | 0.40M |

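Accuracy here is the usual argmax match rate over the 18,800 test samples (90.06% corresponds to roughly 16,931 correct predictions). A minimal sketch of the computation, with toy predictions:

```python
def accuracy_pct(preds, labels):
    # Percentage of predictions that exactly match the labels
    correct = sum(p == l for p, l in zip(preds, labels))
    return round(100 * correct / len(labels), 2)

# Toy example: 9 of 10 predictions match -> 90.0%
preds  = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
labels = [3, 1, 4, 1, 5, 9, 2, 6, 5, 0]
print(accuracy_pct(preds, labels))  # 90.0
```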
### Training Visualization (TensorBoard)

Below are the training and testing curves visualized in TensorBoard.

#### 1. Training Loss

![Training Loss](assets/train_loss.png)
*(Recorded every epoch)*

#### 2. Test Loss & Accuracy
![Test Loss](assets/test_loss.png)
*(Recorded every epoch)*

## Quick Start

### Dependencies
- Python 3.x
- PyTorch
- Torchvision
- Gradio (for the demo)
- Pillow

### Inference / Web Demo

Run the Gradio demo to draw characters and see predictions in real time:

```bash
python demo.py
```

*Note: The demo inverts your drawing (to white ink on a black background) to match the EMNIST format.*

## File Structure

```
.
├── model.py              # Model architecture definition (MiniVisionV3)
├── train.py              # Training script
├── demo.py               # Gradio web interface
├── Mini-Vision-V3.pth    # Trained model weights (epoch 40)
├── config.json
├── README.md
└── assets
    ├── train_loss.png    # Visualized train-loss graph
    └── test_loss.png     # Visualized test-loss graph
```

## License

This project is licensed under the MIT License.
demo.py ADDED
@@ -0,0 +1,46 @@
import gradio
import torch
import torchvision
from model import MiniVisionV3
from PIL import Image, ImageOps


# EMNIST Balanced label map (47 classes); some lowercase letters are merged
# with their uppercase forms, so only a subset of lowercase appears.
old_classes = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, 'A': 10, 'B': 11, 'C': 12, 'D': 13, 'E': 14, 'F': 15, 'G': 16, 'H': 17, 'I': 18, 'J': 19, 'K': 20, 'L': 21, 'M': 22, 'N': 23, 'O': 24, 'P': 25, 'Q': 26, 'R': 27, 'S': 28, 'T': 29, 'U': 30, 'V': 31, 'W': 32, 'X': 33, 'Y': 34, 'Z': 35, 'a': 36, 'b': 37, 'd': 38, 'e': 39, 'f': 40, 'g': 41, 'h': 42, 'n': 43, 'q': 44, 'r': 45, 't': 46}
classes = {v: k for k, v in old_classes.items()}

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(28),
    torchvision.transforms.ToTensor()])


def load_model():
    minivisionv3 = MiniVisionV3()
    # The checkpoint is a plain state dict (tensors only), so the safer
    # weights_only=True load works here.
    state_dict = torch.load("Mini-Vision-V3.pth", weights_only=True)
    minivisionv3.load_state_dict(state_dict)
    minivisionv3.eval()
    return minivisionv3


minivisionv3 = load_model()


def inference(img):
    # The sketchpad gives black ink on white; EMNIST expects the inverse.
    img_convert = ImageOps.invert(img["composite"])
    input = transform(img_convert)
    input = input.unsqueeze(0)  # add batch dimension: (1, 1, 28, 28)

    with torch.no_grad():
        outputs = minivisionv3(input)
        prob = torch.softmax(outputs, 1)

    result = {}
    for i in range(47):
        result[classes[i]] = prob[0][i].item()
    return result


demo = gradio.Interface(fn=inference,
                        inputs=gradio.Sketchpad(height=560, width=560, image_mode="L", label="Draw Here", type="pil"),
                        outputs=gradio.Label(label="Results"),
                        title="Mini-Vision-V3",
                        description="A lightweight CNN (0.4M params) trained on EMNIST Balanced for handwritten character recognition.")

if __name__ == '__main__':
    demo.launch()
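The inference function above maps 47 logits to a {label: probability} dict via softmax. The same postprocessing in dependency-free Python (the logit values below are made up for illustration):

```python
import math

def softmax(logits):
    m = max(logits)                          # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.1] * 47
logits[10] = 5.0                             # pretend index 10 ('A') dominates
probs = softmax(logits)
best = max(range(47), key=lambda i: probs[i])
print(best, round(sum(probs), 6))  # 10 1.0
```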
model.py ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn


class MiniVisionV3(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Linear(1152, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 47),
        )

    def forward(self, x):
        x = self.model(x)
        return x


if __name__ == '__main__':
    minivisionv3 = MiniVisionV3()

    total_params = sum(param.numel() for param in minivisionv3.parameters())
    print(f"Total params: {total_params / 1000000:.2f}M")

    # with torch.no_grad():
    #     input = torch.randn(256, 1, 28, 28)
    #     output = minivisionv3(input)
    #     print(output)
train.py ADDED
@@ -0,0 +1,95 @@
import os
import sys
import torch
import torchvision
from model import MiniVisionV3
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


# Global config
epoch = 50
learningrate = 5e-2
batchsize = 256


save_folder = "MiniVisionV3"
writer = SummaryWriter("../MiniVisionV3_log")

if not os.path.exists(save_folder):
    os.mkdir(save_folder)

# EMNIST images are stored rotated; rotate -90 degrees and flip horizontally
# to restore the standard orientation before augmentation.
transform_correct_train = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda x: F.rotate(x, -90)),
    F.hflip,
    torchvision.transforms.RandomCrop(28, 2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor()
])
transform_correct_test = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda x: F.rotate(x, -90)),
    F.hflip,
    torchvision.transforms.ToTensor()
])

dataset_train = torchvision.datasets.EMNIST("../EMNIST_train", "balanced", train=True, download=True, transform=transform_correct_train)
dataset_test = torchvision.datasets.EMNIST("../EMNIST_test", "balanced", train=False, download=True, transform=transform_correct_test)

dataloader_train = DataLoader(dataset_train, batchsize, True)
dataloader_test = DataLoader(dataset_test, batchsize, False)


minivisionv3 = MiniVisionV3()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(minivisionv3.parameters(), lr=learningrate, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5)

train_datasize = len(dataset_train)
test_datasize = len(dataset_test)
print(f"Train dataset size: {train_datasize}")
print(f"Test dataset size: {test_datasize}")

for i in range(epoch):
    print(f"=============== Epoch {i} Start | LR: {optimizer.param_groups[0]['lr']} ===============")

    minivisionv3.train()
    total_train_loss = 0
    for data in tqdm(dataloader_train, file=sys.stdout):
        optimizer.zero_grad()
        imgs, labels = data
        output = minivisionv3(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item() * len(imgs)

    epoch_train_loss = total_train_loss / train_datasize
    print(f"Train epoch loss: {epoch_train_loss:.2f}")
    writer.add_scalar("Train Loss", epoch_train_loss, i)

    minivisionv3.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in tqdm(dataloader_test, file=sys.stdout):
            imgs, labels = data
            output = minivisionv3(imgs)
            loss = loss_fn(output, labels)

            total_test_loss += loss.item() * len(imgs)
            accuracy = (output.argmax(1) == labels).sum().item()
            total_accuracy += accuracy
    epoch_test_loss = round(total_test_loss / test_datasize, 2)
    print(f"Test epoch loss: {epoch_test_loss}")
    writer.add_scalar("Test Loss", epoch_test_loss, i)

    total_accuracy_percentage = round((total_accuracy / test_datasize) * 100, 2)
    print(f"Test accuracy percentage: {total_accuracy_percentage}%")
    writer.add_scalar("Test Accuracy", total_accuracy_percentage, i)

    scheduler.step()
    torch.save(minivisionv3.state_dict(), f"{save_folder}/MiniVisionV3_Epoch{i}.pth")
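A quick check on the schedule used above: StepLR(step_size=5, gamma=0.5) multiplies the learning rate by 0.5 every 5 epochs, so the best checkpoint at epoch 40 trained at 0.05 * 0.5**8 ≈ 1.95e-4. The closed form, no torch required:

```python
def steplr_lr(initial_lr, step_size, gamma, epoch):
    # Learning rate in effect during `epoch` (0-indexed),
    # matching torch.optim.lr_scheduler.StepLR's decay pattern.
    return initial_lr * gamma ** (epoch // step_size)

for e in (0, 5, 20, 40):
    print(e, steplr_lr(0.05, 5, 0.5, e))
# epoch 40 -> 0.05 * 0.5**8 = 0.0001953125
```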