# Source: Hugging Face Space file view (Space status: Sleeping).
# File size: 1,462 bytes; commits 4027be6, 34b3108.
import torch
from torch import nn
from torchvision import transforms
class MnistModel(nn.Module):
    """Small two-stage CNN that maps single-channel images to 10 digit logits.

    Shape walk-through for a batch of shape (N, 1, 28, 28):

        conv1 (3x3) -> (N, 3, 26, 26) -> maxpool -> (N, 3, 13, 13)
        conv2 (3x3) -> (N, 6, 11, 11) -> maxpool -> (N, 6, 5, 5)
        flatten     -> (N, 150) -> fc1 -> (N, 32) -> fc2 -> (N, 10)

    ``fc1`` expects exactly 150 features, so other spatial sizes fail there.
    """

    # Human-readable labels; position i corresponds to output logit i.
    classes = ['0 - zero',
               '1 - one',
               '2 - two',
               '3 - three',
               '4 - four',
               '5 - five',
               '6 - six',
               '7 - seven',
               '8 - eight',
               '9 - nine']

    def __init__(self, *args, **kwargs) -> None:
        """Create the layers; extra arguments are forwarded to ``nn.Module``."""
        super().__init__(*args, **kwargs)
        # These attribute names determine the state_dict keys used by saved
        # checkpoints, so they must stay unchanged.
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.conv2 = nn.Conv2d(3, 6, 3)
        self.maxpool = nn.MaxPool2d(2, 2)
        # 150 = 6 channels * 5 * 5 after two conv + pool stages on 28x28 input.
        self.fc1 = nn.Linear(150, 32)
        self.fc2 = nn.Linear(32, 10)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw (un-normalized) class logits of shape (N, 10).

        Args:
            x: Batched input of shape (N, 1, 28, 28).
        """
        # Functional relu avoids building a new nn.ReLU module on every call.
        x = self.maxpool(torch.relu(self.conv1(x)))
        x = self.maxpool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)
def load_model(path='best_model.pth', map_location='cpu'):
    """Build an ``MnistModel``, load trained weights, and return inference helpers.

    Args:
        path: Location of the saved ``state_dict``. A relative path is resolved
            against the current working directory. Defaults to
            ``'best_model.pth'``.
        map_location: Device the stored tensors are mapped to. Defaults to
            ``'cpu'``.

    Returns:
        A ``(model, transform, classes)`` tuple:

        - the model, switched to evaluation mode;
        - the preprocessing pipeline (resize to 28x28, convert to tensor,
          reduce to one channel);
        - the list of class labels.
    """
    model = MnistModel()
    transforming = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        # Runs after ToTensor, so it operates on a tensor.
        # NOTE(review): tensor grayscale expects 1- or 3-channel input;
        # confirm callers never pass RGBA images.
        transforms.Grayscale(num_output_channels=1),
    ])
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    # nn.Module starts in training mode, which would keep Dropout active and
    # make predictions non-deterministic. Switch to inference mode.
    model.eval()
    return model, transforming, model.classes
if __name__ == '__main__':
    # Module is meant to be imported (see load_model); nothing to run directly.
    pass