from __future__ import print_function

import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torchvision import datasets, transforms

import CNNmodel

# Seed PyTorch's global RNG so every run starts from the same random state.
torch.manual_seed(100)

# Use the GPU when CUDA is actually usable; otherwise fall back to the CPU so
# the script still runs on CUDA-less machines instead of failing as soon as
# the model is moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training data: the MNIST training split stored under ../data (downloaded if
# missing), converted to tensors and normalised with mean 0.1307 / std 0.3081,
# then served in shuffled mini-batches of 64.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=train_transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)

# Evaluation data: the MNIST test split read from ../data with the same tensor
# conversion and normalisation as the training data, served in shuffled
# batches of 1000. No download is requested here, so the files are expected
# to be present in ../data already.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
test_dataset = datasets.MNIST(
    '../data',
    train=False,
    transform=test_transform,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=True,
)

# Instantiate the network from CNNmodel and move it to the selected device.
model = CNNmodel.Net()
model = model.to(device)

# Plain SGD with momentum over all of the model's parameters.
learning_rate = 0.01
momentum = 0.5
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

# Run a fixed number of epochs; each epoch prints its number, then calls
# CNNmodel.train on the training loader followed by CNNmodel.test on the
# test loader.
save_model = True
num_epochs = 5

for epoch in range(1, num_epochs + 1):
    print(epoch)
    CNNmodel.train(model, device, train_loader, optimizer, epoch)
    CNNmodel.test(model, device, test_loader)

# Write the model's state dict to disk once all epochs have finished.
if save_model:
    torch.save(model.state_dict(), "mnist_cnn.pt")