"""Train a small LeNet-style network on the MNIST dataset and export it to ONNX."""

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


def get_mnist_dataset():
    """Return the MNIST ``(train_set, test_set)`` pair.

    Both datasets are stored under ``./data`` (downloaded when requested via
    ``download=True``) and use ``transforms.ToTensor()`` as their transform.
    """
    transform = transforms.ToTensor()
    train_set = datasets.MNIST(
        root='./data', train=True, transform=transform, download=True
    )
    test_set = datasets.MNIST(
        root='./data', train=False, transform=transform, download=True
    )
    return train_set, test_set


class Classifier(torch.nn.Module):
    """LeNet-style classifier: two conv/pool stages followed by three linear layers.

    The layers live in ``self.network`` (an ``nn.Sequential``) so that the bare
    network can be handed to the ONNX exporter directly.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 5),    # 28 -> 24
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 24 -> 12
            nn.Conv2d(32, 32, 5),   # 12 -> 8
            nn.ReLU(),
            nn.MaxPool2d(2, 2),     # 8 -> 4
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Run ``x`` through the sequential network and return its 10 outputs."""
        return self.network(x)


def compute_accuracy(model, data_set, nb_samples: int) -> float:
    """Estimate the accuracy of ``model`` on randomly drawn samples.

    Args:
        model: Callable module mapping a ``(1, 1, 28, 28)`` tensor to class scores.
        data_set: Indexable collection of ``(image, label)`` pairs.
        nb_samples: Number of samples to draw (independently, so repeats are
            possible). Must be non-zero; zero raises ``ZeroDivisionError``.

    Returns:
        Fraction of drawn samples whose arg-max prediction equals the label.
    """
    nb_valid = 0
    # Evaluation only: no autograd graph is needed for these forward passes.
    with torch.no_grad():
        for _ in range(nb_samples):
            sample_idx = torch.randint(len(data_set), size=(1,)).item()
            img, label = data_set[sample_idx]

            x = torch.reshape(img, (1, 1, 28, 28))
            pred_label = torch.argmax(model(x)).item()
            if label == pred_label:
                nb_valid += 1
    return nb_valid / nb_samples


def train_model(NB_ITERATION, CHECK_PERIOD, train_set, test_set, classifier,
                learning_rate=0.01):
    """Train ``classifier`` with single-sample gradient descent on an MSE loss.

    Args:
        NB_ITERATION: Number of training steps (one random sample per step).
        CHECK_PERIOD: Accuracy is evaluated every ``CHECK_PERIOD`` steps
            (including step 0), on ``CHECK_PERIOD`` random test samples.
        train_set: Indexable collection of ``(image, label)`` training pairs.
        test_set: Indexable collection of ``(image, label)`` evaluation pairs.
        classifier: Module to train in place.
        learning_rate: Step size of the manual gradient-descent update.

    Returns:
        List of the accuracies measured at each check, in order.
    """
    accuracy_history = []
    for it in range(NB_ITERATION):
        sample_idx = random.randint(0, len(train_set) - 1)
        img, label = train_set[sample_idx]

        x = torch.reshape(img, (1, 1, 28, 28))
        # One-hot target for the sample's label.
        y = torch.zeros(1, 10)
        y[0, label] = 1

        y_h = classifier(x)
        # mse_loss takes (input, target); the prediction goes first.
        loss = F.mse_loss(y_h, y)
        loss.backward()

        # Manual gradient-descent step; updates must not be tracked by autograd.
        with torch.no_grad():
            for p in classifier.parameters():
                if p.grad is None:
                    continue
                p -= learning_rate * p.grad
                p.grad.zero_()

        if it % CHECK_PERIOD == 0:
            accuracy = compute_accuracy(classifier, test_set, CHECK_PERIOD)
            accuracy_history.append(accuracy)
            print(f'it {it}: \n accuracy = {accuracy:.8f} ')
    return accuracy_history


def create_lenet():
    """Download MNIST, train a fresh ``Classifier`` and export it as ``lenet.onnx``."""
    # Get Dataset
    train_set, test_set = get_mnist_dataset()

    # Create model
    classifier = Classifier()

    # Train model
    nb_iteration = 50000
    check_period = 3000
    print("NB_ITERATIONS = ", nb_iteration)
    print("CHECK_PERIOD = ", check_period)
    print("\nTraining LeNet...")
    train_model(nb_iteration, check_period, train_set, test_set, classifier)

    # Export as ONNX. The dummy input only fixes the traced shape; use zeros
    # rather than uninitialised memory so the trace input is well defined.
    x = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(
        classifier.network,
        x,
        'lenet.onnx',
        verbose=False,
        input_names=["input"],
        output_names=["output"],
    )