File size: 1,462 Bytes
4027be6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34b3108
4027be6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

import torch
from torch import nn
from torchvision import transforms


class MnistModel(nn.Module):
    """Small two-convolution CNN producing 10-way digit logits.

    The layer sizes are chosen so that a single-channel 28x28 input shrinks as
    28 -> 26 (conv1, 3x3) -> 13 (pool) -> 11 (conv2, 3x3) -> 5 (pool), giving
    6 * 5 * 5 = 150 flattened features, which is ``fc1``'s input width.

    The submodule attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) are
    the ``state_dict`` keys, so renaming them breaks loading saved checkpoints.
    """

    # Label for each output index of ``forward`` (index i -> classes[i]).
    # Class-level, so it is shared by all instances.
    classes = ['0 - zero',
                '1 - one',
                '2 - two',
                '3 - three',
                '4 - four',
                '5 - five',
                '6 - six',
                '7 - seven',
                '8 - eight',
                '9 - nine']

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.conv2 = nn.Conv2d(3, 6, 3)
        # One 2x2 / stride-2 pool reused after both convolutions; it has no
        # parameters, so sharing it is safe.
        self.maxpool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(150, 32)
        self.fc2 = nn.Linear(32, 10)
        # Active only in training mode; a no-op after ``.eval()``.
        self.dropout = nn.Dropout(0.3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class logits of shape ``(N, 10)``.

        Args:
            x: image batch with one channel; ``fc1``'s width of 150 implies a
                spatial size such as 28x28, i.e. ``(N, 1, 28, 28)``.
        """
        # Functional ReLU: avoids constructing a new nn.ReLU module per call.
        x = self.maxpool(torch.relu(self.conv1(x)))
        x = self.maxpool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        # No softmax: the result is raw logits.
        return self.fc2(x)
    

def load_model(path='best_model.pth'):
    """Build an ``MnistModel``, load saved weights, and prepare it for inference.

    Args:
        path: file containing the model's ``state_dict``. Relative paths are
            resolved against the current working directory. Defaults to
            ``'best_model.pth'``.

    Returns:
        A tuple ``(model, transform, classes)``:
        - ``model``: the loaded network, switched to eval mode;
        - ``transform``: preprocessing pipeline taking a PIL image to a
          ``(1, 28, 28)`` float tensor;
        - ``classes``: the model's list of class labels.
    """
    model = MnistModel()
    transforming = transforms.Compose([
        transforms.Resize((28, 28)),
        # Convert to grayscale on the PIL image, before ToTensor: PIL handles
        # RGBA / palette ('P') / other modes, whereas the tensor grayscale
        # path only accepts 1 or 3 channels (RGBA raises, 'P' would be read as
        # raw palette indices).
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
    ])

    # torch.load unpickles the file; only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Switch off dropout so predictions are deterministic at inference time.
    model.eval()

    return model, transforming, model.classes

if __name__ == '__main__':
    # Intentionally empty: this module is meant to be imported, e.g. via
    # load_model(); running it directly does nothing.
    pass