Yaning1001's picture
Add files using upload-large-folder tool
c91d7b1 verified
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F #233
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from PIL import Image
import CNNmodel
torch.manual_seed(100)
device = torch.device("cuda")
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=64,
shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False,
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=1000,
shuffle=True)
model = CNNmodel.Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
save_model = True
for epoch in range(1, 5 + 1): ## 5 batches
print(epoch)
CNNmodel.train(model, device, train_loader, optimizer, epoch)
CNNmodel.test(model, device, test_loader)
if (save_model):
torch.save(model.state_dict(), "mnist_cnn.pt")