Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from torchvision import transforms | |
class MnistModel(nn.Module):
    """Small CNN that maps 28x28 single-channel images to 10 digit logits.

    Layout: conv(1->3, 3x3) -> ReLU -> 2x2 max-pool -> conv(3->6, 3x3)
    -> ReLU -> 2x2 max-pool -> flatten (150) -> linear(150->32) -> ReLU
    -> dropout(p=0.3) -> linear(32->10).
    """

    # Human-readable label for each output position (index == digit).
    classes = ['0 - zero',
               '1 - one',
               '2 - two',
               '3 - three',
               '4 - four',
               '5 - five',
               '6 - six',
               '7 - seven',
               '8 - eight',
               '9 - nine']

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(1, 3, 3)
        self.conv2 = nn.Conv2d(3, 6, 3)
        self.maxpool = nn.MaxPool2d(2, 2)
        # Spatial size for a 28x28 input: 28 -conv-> 26 -pool-> 13 -conv-> 11
        # -pool-> 5, so the flattened feature size is 6 * 5 * 5 = 150.
        self.fc1 = nn.Linear(150, 32)
        self.fc2 = nn.Linear(32, 10)
        # Only active in training mode; call .eval() for inference.
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw (un-normalised, no softmax) class scores for ``x``.

        Args:
            x: image tensor shaped (N, 1, 28, 28), or a single unbatched
                image shaped (1, 28, 28).

        Returns:
            Tensor of logits shaped (N, 10), or (10,) for unbatched input.
        """
        x = self.maxpool(torch.relu(self.conv1(x)))
        x = self.maxpool(torch.relu(self.conv2(x)))
        # Flatten only the trailing (C, H, W) dims so both batched and
        # unbatched inputs work; for batched input this equals flatten(x, 1).
        x = torch.flatten(x, start_dim=-3)
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)
def load_model(path='best_model.pth'):
    """Build an MnistModel, load its weights, and prepare it for inference.

    Args:
        path: file containing a state_dict for MnistModel. Defaults to
            'best_model.pth', resolved relative to the working directory.

    Returns:
        A ``(model, transform, classes)`` tuple:
        - model: the MnistModel with weights loaded (mapped to CPU) and
          switched to eval mode, so dropout is disabled.
        - transform: preprocessing that converts an image to a 28x28
          single-channel tensor (assumes a PIL image is passed in).
        - classes: the label list indexed by output position.
    """
    model = MnistModel()
    # Grayscale runs on the PIL image first. PIL's convert('L') accepts any
    # mode (RGB, RGBA, LA, P, ...). Grayscale applied after ToTensor only
    # accepts 1- or 3-channel tensors, and palette images would become raw
    # palette indices.
    transforming = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints (on torch >= 1.13,
    # passing weights_only=True restricts what can be unpickled).
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Inference mode: without this the Dropout layer stays active and
    # predictions are randomly perturbed on every call.
    model.eval()
    return model, transforming, model.classes
# This module only provides MnistModel and load_model(); running it as a
# script is intentionally a no-op.
if __name__ == "__main__":
    ...