"""Train MiniVisionV2 on MNIST, logging to TensorBoard and checkpointing every epoch."""
import os
import sys

import torch
import torchvision
from datasets import load_dataset
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from model import MiniVisionV2

save_path = "minivisionv2_model"  # directory receiving one checkpoint per epoch
batchsize = 256
learningrate = 1e-2
epoch = 50  # total number of training epochs

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(save_path, exist_ok=True)
writer = SummaryWriter("minivisionv2_logs")

dataset = load_dataset("ylecun/mnist")

# Training augmentation: random shift (28x28 crop after 2px padding) plus a
# small random rotation; evaluation only converts images to tensors.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(28, padding=2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor(),
])
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])


def transforms_train(data):
    """Add a "tensor" entry of augmented image tensors to a batch dict.

    Registered via ``Dataset.with_transform``; ``data["image"]`` is presumably
    a list of PIL images from the MNIST dataset (TODO confirm).
    """
    data["tensor"] = [transform_train(img) for img in data["image"]]
    return data


def transforms_test(data):
    """Add a "tensor" entry of un-augmented image tensors to a batch dict."""
    data["tensor"] = [transform_test(img) for img in data["image"]]
    return data


train_dataset = dataset["train"].with_transform(transforms_train)
test_dataset = dataset["test"].with_transform(transforms_test)


def collate_fn(batch):
    """Stack per-sample "tensor" and "label" fields into batched tensors.

    Only these two keys are kept; the raw "image" field is dropped here.
    """
    return {
        "tensor": torch.stack([x["tensor"] for x in batch]),
        "label": torch.tensor([x["label"] for x in batch]),
    }


train_loader = DataLoader(
    train_dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset, batch_size=batchsize, shuffle=False, collate_fn=collate_fn
)

minivisionv2 = MiniVisionV2()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    minivisionv2.parameters(), lr=learningrate, momentum=0.8
)
# Halve the learning rate every 3 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

try:
    for i in range(epoch):
        # Read the LR into a local first: nesting "lr" inside a double-quoted
        # f-string is a SyntaxError before Python 3.12 (PEP 701).
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"=============== Epoch {i} Start | LR: {current_lr} ===============")

        # ---- training ----
        minivisionv2.train()
        total_train_loss = 0.0
        for data in tqdm(train_loader, file=sys.stdout):
            imgs = data["tensor"]
            labels = data["label"]
            optimizer.zero_grad()
            output = minivisionv2(imgs)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        total_avg_train_loss = total_train_loss / len(train_loader)
        print(f"Train loss: {total_avg_train_loss}")
        writer.add_scalar("Train Loss", total_avg_train_loss, i)

        # ---- evaluation ----
        minivisionv2.eval()
        total_accuracy = 0
        total_test_loss = 0.0
        with torch.no_grad():
            for data in tqdm(test_loader, file=sys.stdout):
                imgs = data["tensor"]
                labels = data["label"]
                output = minivisionv2(imgs)
                # .item() keeps the running sum a Python float, matching the
                # training loop, so printed/logged values are plain numbers.
                total_test_loss += loss_fn(output, labels).item()
                total_accuracy += (output.argmax(1) == labels).sum().item()

        total_avg_test_loss = total_test_loss / len(test_loader)
        total_accuracy_percentage = round(total_accuracy / len(test_dataset) * 100, 2)
        print(f"Test loss: {total_avg_test_loss}")
        print(f"Test Accuracy Percentage: {total_accuracy_percentage}%")
        writer.add_scalar("Test Loss", total_avg_test_loss, i)
        writer.add_scalar("Test Accuracy Percentage", total_accuracy_percentage, i)

        # Pickles the whole module object, so loading requires MiniVisionV2 to
        # be importable. Saving state_dict() would be more portable but would
        # change the checkpoint format existing consumers may rely on.
        torch.save(minivisionv2, f"./{save_path}/Mini-Vision-V2-Epoch-{i}.pth")
        print("Model Saved!")

        scheduler.step()
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()