import os
import sys

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from model import *

# ---------------------------------------------------------------------------
# CIFAR-10 training script for MyNetwork (defined in model.py).
# Trains with SGD + StepLR, logs loss/accuracy to TensorBoard, and saves a
# checkpoint after every epoch.
# ---------------------------------------------------------------------------

# Checkpoint directory.
save_dir = "mini-vision"
os.makedirs(save_dir, exist_ok=True)  # exist_ok: no crash on re-runs

# TensorBoard visualization.
writer = SummaryWriter("mini-vision-logs")

# Training configuration.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batchsize = 256
learning_rate = 7e-3

# Train-time augmentation: pad-and-crop plus horizontal flip is the standard
# CIFAR-10 recipe; the test set only gets ToTensor.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, 4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
])

# Datasets (downloaded on first run).
train_data = torchvision.datasets.CIFAR10(
    "CIFAR10", True, train_transforms, download=True)
test_data = torchvision.datasets.CIFAR10(
    "CIFAR10", False, torchvision.transforms.ToTensor(), download=True)

# Dataset lengths.
train_data_size = len(train_data)
test_data_size = len(test_data)
print(train_data_size)
print(test_data_size)

# Data loaders: shuffle only the training set.
train_dataloader = DataLoader(train_data, batchsize, True)
test_dataloader = DataLoader(test_data, batchsize, False)

# Model, loss, optimizer, and LR schedule (SGD momentum 0.9; halve LR every
# 5 epochs).
mynetwork = MyNetwork().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(mynetwork.parameters(), learning_rate, 0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5)  # fixed "schedular" typo

# Global step counters for TensorBoard.
total_train_step = 0
total_test_step = 0

# Training epochs.
epoch = 100
for i in range(epoch):
    # Nested double quotes inside a double-quoted f-string are a SyntaxError
    # before Python 3.12, so the LR is pulled into a variable first.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"---------------Epoch {i + 1} start, LR:{current_lr}---------------")

    # ---- training ----
    mynetwork.train()
    total_train_loss = 0
    print("Training Progress: ", flush=True)
    for imgs, targets in tqdm(train_dataloader, file=sys.stdout):
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = mynetwork(imgs)
        loss = loss_fn(output, targets)

        # Optimize the model.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
        total_train_step += 1
        # Log at the already-incremented step; the original logged at
        # total_train_step + 1 after incrementing, which skipped step 1 and
        # disagreed with the test-scalar numbering.
        writer.add_scalar("train_loss", loss.item(), total_train_step)

    # Average over the actual number of batches; dividing by
    # train_data_size / batchsize (as before) is wrong when the last batch
    # is partial.
    total_train_loss /= len(train_dataloader)
    print(f"Total avg loss on train data: {total_train_loss:.2f}", flush=True)

    # ---- evaluation ----
    mynetwork.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        print("Testing Progress", flush=True)
        for imgs, targets in tqdm(test_dataloader, file=sys.stdout):
            imgs = imgs.to(device)
            targets = targets.to(device)
            output = mynetwork(imgs)
            loss = loss_fn(output, targets)
            total_test_loss += loss.item()
            # .item() keeps the running count a Python int instead of
            # accumulating device tensors across the whole test set.
            total_accuracy += (output.argmax(1) == targets).sum().item()

    accuracy_percentage = round(total_accuracy / test_data_size * 100, 2)
    total_test_loss /= len(test_dataloader)
    print(f"Total avg loss on test data: {total_test_loss:.2f}", flush=True)
    print(f"Accuracy on test data: {accuracy_percentage}%", flush=True)
    total_test_step += 1
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", accuracy_percentage, total_test_step)

    scheduler.step()

    # NOTE(review): this pickles the full model object; saving
    # mynetwork.state_dict() is the more portable convention — kept as-is so
    # any existing torch.load(...) consumers of these files keep working.
    torch.save(mynetwork, f"{save_dir}/Mini-Vision-V1{i + 1}.pth")  # save every epoch
    print("Model saved", flush=True)

writer.close()