| import os
|
| import sys
|
| import torch
|
| import torchvision
|
| from model import MiniVisionV3
|
| from torch.utils.data import DataLoader
|
| from torchvision.transforms import functional as F
|
| from tqdm import tqdm
|
| from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Training configuration.
epoch = 50           # number of training epochs
learningrate = 5e-2  # initial SGD learning rate (halved every 5 epochs by the scheduler)
batchsize = 256      # mini-batch size for both train and test loaders

# Checkpoint directory and TensorBoard log writer.
save_folder = "MiniVisionV3"
writer = SummaryWriter("../MiniVisionV3_log")

# makedirs(..., exist_ok=True) is race-free and creates missing parents,
# unlike the exists()-then-mkdir() pair it replaces (which could fail if the
# directory appeared between the check and the call).
os.makedirs(save_folder, exist_ok=True)
|
|
|
# EMNIST images are stored transposed relative to MNIST; rotating by -90
# degrees and then flipping horizontally restores the natural orientation.
# The train pipeline adds light augmentation (random crop with padding and a
# small random rotation) before tensor conversion; the test pipeline only
# fixes the orientation.
transform_correct_train = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda img: F.hflip(F.rotate(img, -90))),
    torchvision.transforms.RandomCrop(28, 2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor(),
])
transform_correct_test = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda img: F.hflip(F.rotate(img, -90))),
    torchvision.transforms.ToTensor(),
])
|
|
|
# A single shared root lets the train and test splits reuse one downloaded
# EMNIST archive; the original separate "../EMNIST_train" / "../EMNIST_test"
# roots downloaded and extracted the (large) archive twice.
data_root = "../EMNIST"
dataset_train = torchvision.datasets.EMNIST(
    data_root, "balanced", train=True, download=True,
    transform=transform_correct_train,
)
dataset_test = torchvision.datasets.EMNIST(
    data_root, "balanced", train=False, download=True,
    transform=transform_correct_test,
)

# Shuffle only the training data; evaluation order does not matter.
dataloader_train = DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batchsize, shuffle=False)
|
|
|
|
|
# Model, objective, and optimization setup.
minivisionv3 = MiniVisionV3()

# Standard cross-entropy over the 47 balanced-EMNIST classes.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    minivisionv3.parameters(), lr=learningrate, momentum=0.8
)
# Halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
|
|
|
# Report split sizes before training starts.
train_datasize, test_datasize = len(dataset_train), len(dataset_test)
print(f"Train dataset size: {train_datasize}")
print(f"test dataset size: {test_datasize}")
|
|
|
for i in range(epoch):
    print(f"=============== Epoch {i} Start | LR: {optimizer.param_groups[0]['lr']} ===============")

    # ---- training pass ----
    minivisionv3.train()
    total_train_loss = 0.0
    for imgs, labels in tqdm(dataloader_train, file=sys.stdout):
        optimizer.zero_grad()
        output = minivisionv3(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch average is exact even when the
        # final batch is smaller than `batchsize`.
        total_train_loss += loss.item() * len(imgs)

    epoch_train_loss = total_train_loss / train_datasize
    print(f"Train epoch loss: {epoch_train_loss:.2f}")
    writer.add_scalar("Train Loss", epoch_train_loss, i)

    # ---- evaluation pass ----
    minivisionv3.eval()
    total_test_loss = 0.0
    total_accuracy = 0
    with torch.no_grad():
        for imgs, labels in tqdm(dataloader_test, file=sys.stdout):
            output = minivisionv3(imgs)
            loss = loss_fn(output, labels)

            total_test_loss += loss.item() * len(imgs)
            # Count correct predictions in this batch.
            total_accuracy += (output.argmax(1) == labels).sum().item()

    # Log the full-precision values (the original rounded before logging,
    # discarding precision in TensorBoard and diverging from how the train
    # loss is logged); round only for console display.
    epoch_test_loss = total_test_loss / test_datasize
    print(f"Test epoch loss: {epoch_test_loss:.2f}")
    writer.add_scalar("Test Loss", epoch_test_loss, i)

    total_accuracy_percentage = (total_accuracy / test_datasize) * 100
    print(f"Test accuracy percentage: {total_accuracy_percentage:.2f}%")
    writer.add_scalar("Test Accuracy", total_accuracy_percentage, i)

    # Step the LR schedule once per epoch, then checkpoint the weights.
    scheduler.step()
    torch.save(minivisionv3.state_dict(), f"{save_folder}/MiniVisionV3_Epoch{i}.pth")
|
|
|