"""Train MiniVisionV3 on the EMNIST "balanced" split.

Each epoch trains on the train split, evaluates on the test split, logs
train loss / test loss / test accuracy to TensorBoard, steps the LR
scheduler and saves a checkpoint of the model's state_dict.
"""
import os
import sys

import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import functional as F
from tqdm import tqdm

from model import MiniVisionV3

# Global config
epoch = 50  # number of training epochs
learningrate = 5e-2
batchsize = 256
save_folder = "MiniVisionV3"  # per-epoch checkpoints are written here

writer = SummaryWriter("../MiniVisionV3_log")

# makedirs(exist_ok=True) avoids the race between an exists() check and mkdir().
os.makedirs(save_folder, exist_ok=True)

# rotate(-90) followed by hflip amounts to a transpose of each image.
# NOTE(review): presumably this undoes EMNIST's transposed storage so the
# characters come out upright -- confirm against the dataset docs.
transform_correct_train = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda x: F.rotate(x, -90)),
    F.hflip,
    # Augmentation: random 28x28 crop after 2px padding, then a random
    # rotation within +/-10 degrees.
    torchvision.transforms.RandomCrop(28, padding=2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor(),
])
transform_correct_test = torchvision.transforms.Compose([
    torchvision.transforms.Lambda(lambda x: F.rotate(x, -90)),
    F.hflip,
    torchvision.transforms.ToTensor(),
])

dataset_train = torchvision.datasets.EMNIST(
    "../EMNIST_train", "balanced", train=True, download=True,
    transform=transform_correct_train,
)
dataset_test = torchvision.datasets.EMNIST(
    "../EMNIST_test", "balanced", train=False, download=True,
    transform=transform_correct_test,
)

dataloader_train = DataLoader(dataset_train, batchsize, shuffle=True)
dataloader_test = DataLoader(dataset_test, batchsize, shuffle=False)

minivisionv3 = MiniVisionV3()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(minivisionv3.parameters(), lr=learningrate, momentum=0.8)
# Multiply the learning rate by 0.5 every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

train_datasize = len(dataset_train)
test_datasize = len(dataset_test)
print(f"Train dataset size: {train_datasize}")
print(f"test dataset size: {test_datasize}")


def _train_one_epoch(model, loader, criterion, opt):
    """Run one optimization pass over *loader* and return the mean per-sample loss."""
    model.train()
    total_loss = 0.0
    seen = 0
    for imgs, labels in tqdm(loader, file=sys.stdout):
        opt.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        opt.step()
        # Assumes criterion returns the batch mean (CrossEntropyLoss default);
        # weighting by batch size keeps the epoch average per-sample even when
        # the last batch is smaller.
        total_loss += loss.item() * len(imgs)
        seen += len(imgs)
    return total_loss / seen


def _evaluate(model, loader, criterion):
    """Return (mean per-sample loss, accuracy in percent) of *model* over *loader*."""
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, file=sys.stdout):
            output = model(imgs)
            total_loss += criterion(output, labels).item() * len(imgs)
            # Predicted class is the argmax over dim 1 of the model output.
            correct += (output.argmax(1) == labels).sum().item()
            seen += len(imgs)
    return total_loss / seen, correct / seen * 100


try:
    for i in range(epoch):
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"=============== Epoch {i} Start | LR: {current_lr} ===============")

        epoch_train_loss = _train_one_epoch(minivisionv3, dataloader_train, loss_fn, optimizer)
        print(f"Train epoch loss: {epoch_train_loss:.2f}")
        writer.add_scalar("Train Loss", epoch_train_loss, i)

        epoch_test_loss, total_accuracy_percentage = _evaluate(
            minivisionv3, dataloader_test, loss_fn
        )
        # Round only for display; TensorBoard receives full-precision values.
        print(f"Test epoch loss: {epoch_test_loss:.2f}")
        writer.add_scalar("Test Loss", epoch_test_loss, i)
        print(f"Test accuracy percentage: {total_accuracy_percentage:.2f}%")
        writer.add_scalar("Test Accuracy", total_accuracy_percentage, i)

        scheduler.step()
        torch.save(
            minivisionv3.state_dict(),
            os.path.join(save_folder, f"MiniVisionV3_Epoch{i}.pth"),
        )
finally:
    # Flush pending events and release the log file even if training is
    # interrupted (e.g. Ctrl-C or an exception mid-epoch).
    writer.close()