PyTorch MNIST training script: a feed-forward classifier (NeuralNet) trained with Adam and cross-entropy loss.
import os

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from NeuralNet import NeuralNet
# Device configuration: use the GPU when one is available, else fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
input_size = 784        # flattened 28x28 MNIST image
hidden_size = 100       # width of the single hidden layer
num_classes = 10        # digits 0-9
num_epochs = 20
batch_size = 500
learning_rate = 0.001
# MNIST datasets. download=True fetches the archive into ./data on first run;
# the test split reuses the same on-disk cache, so no download flag is needed there.
training_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(
    dataset=training_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Sanity check: print the shape of one batch —
# expected (batch_size, 1, 28, 28) for images and (batch_size,) for labels.
example = iter(train_loader)
samples, labels = next(example)
print(samples.shape, labels.shape)
# Model, loss and optimizer.
# BUG FIX: the model must live on the same device as the batches; the training
# loop moves images/labels to `device`, so a CUDA run crashed with a
# device-mismatch error before this .to(device) was added.
model = NeuralNet(input_size, hidden_size, num_classes).to(device)

criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Training loop: num_epochs full passes over the training set.
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Batches arrive as (batch_size, 1, 28, 28); flatten each image to a
        # 784-vector to match the fully-connected input layer.
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'epoch {epoch+1}/{num_epochs}, step {i+1}/{n_total_steps}, loss = {loss.item():.4f}')
# Evaluation: accuracy over the held-out test set.
# model.eval() switches train-only layers (dropout/batch-norm) to inference
# mode — a no-op for a plain MLP but required if such layers are ever added;
# torch.no_grad() skips gradient tracking to save memory and time.
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28).to(device)
        labels = labels.to(device)
        outputs = model(images)
        # torch.max over dim 1 returns (values, indices); the index is the
        # predicted class.
        _, predictions = torch.max(outputs, 1)
        n_samples += labels.shape[0]
        n_correct += (predictions == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'accuracy = {acc}')
# Persist the learned weights. Create the target directory first —
# torch.save raises FileNotFoundError if 'model/' does not exist.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), 'model/model.pt')