import os
import sys
import torch
import torchvision
from model import MiniVisionV3
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# ---------------------------------------------------------------------------
# Global training configuration
# ---------------------------------------------------------------------------
epoch = 50            # number of training epochs
learningrate = 5e-2   # initial SGD learning rate (halved every 5 epochs by the scheduler)
batchsize = 256       # mini-batch size for both train and test loaders
save_folder = "MiniVisionV3"  # directory where per-epoch checkpoints are written
writer = SummaryWriter("../MiniVisionV3_log")  # TensorBoard event-file directory
# makedirs with exist_ok=True is race-free and also creates missing parent
# directories, unlike the check-then-mkdir pattern it replaces.
os.makedirs(save_folder, exist_ok=True)
# EMNIST images are stored rotated and mirrored relative to the canonical
# MNIST orientation; rotating -90 degrees then flipping horizontally restores
# the upright glyphs. Both pipelines share this correction.
_orientation_fix = [
    torchvision.transforms.Lambda(lambda img: F.rotate(img, -90)),
    F.hflip,
]

# Training pipeline: orientation fix, light augmentation (padded random crop
# and small random rotation), then tensor conversion.
transform_correct_train = torchvision.transforms.Compose(
    _orientation_fix
    + [
        torchvision.transforms.RandomCrop(28, 2),
        torchvision.transforms.RandomRotation(10),
        torchvision.transforms.ToTensor(),
    ]
)

# Test pipeline: orientation fix only — no augmentation at evaluation time.
transform_correct_test = torchvision.transforms.Compose(
    _orientation_fix + [torchvision.transforms.ToTensor()]
)
# Datasets: the EMNIST "balanced" split, downloaded on first run.
# NOTE(review): train and test use two separate root directories, so the
# EMNIST archive is fetched and stored twice — consider sharing one root.
dataset_train = torchvision.datasets.EMNIST(
    "../EMNIST_train", "balanced", train=True, download=True,
    transform=transform_correct_train,
)
dataset_test = torchvision.datasets.EMNIST(
    "../EMNIST_test", "balanced", train=False, download=True,
    transform=transform_correct_test,
)

# Loaders: shuffle only the training set.
dataloader_train = DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batchsize, shuffle=False)

# Model, loss, optimizer, and LR schedule (halve the LR every 5 epochs).
minivisionv3 = MiniVisionV3()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(minivisionv3.parameters(), lr=learningrate, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

train_datasize = len(dataset_train)
test_datasize = len(dataset_test)
print(f"Train dataset size: {train_datasize}")
print(f"test dataset size: {test_datasize}")
# Main train/evaluate loop: one pass over the training set, one evaluation
# pass, TensorBoard logging, LR step, and a checkpoint per epoch.
for i in range(epoch):
    print(f"=============== Epoch {i} Start | LR: {optimizer.param_groups[0]['lr']} ===============")

    # ---- Training phase ----
    minivisionv3.train()
    total_train_loss = 0
    for data in tqdm(dataloader_train, file=sys.stdout):
        optimizer.zero_grad()
        imgs, labels = data
        output = minivisionv3(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        # Weight the mean batch loss by batch size so the epoch average is
        # exact even when the final batch is smaller than `batchsize`.
        total_train_loss += loss.item() * len(imgs)
    epoch_train_loss = total_train_loss / train_datasize
    print(f"Train epoch loss: {epoch_train_loss:.2f}")
    writer.add_scalar("Train Loss", epoch_train_loss, i)

    # ---- Evaluation phase ----
    minivisionv3.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data in tqdm(dataloader_test, file=sys.stdout):
            imgs, labels = data
            output = minivisionv3(imgs)
            loss = loss_fn(output, labels)
            total_test_loss += loss.item() * len(imgs)
            # Count correct top-1 predictions in this batch.
            accuracy = (output.argmax(1) == labels).sum().item()
            total_accuracy += accuracy
    epoch_test_loss = round(total_test_loss / test_datasize, 2)
    print(f"Test epoch loss: {epoch_test_loss}")
    writer.add_scalar("Test Loss", epoch_test_loss, i)
    total_accuracy_percentage = round((total_accuracy / test_datasize) * 100, 2)
    print(f"Test accuracy percentage: {total_accuracy_percentage}%")
    writer.add_scalar("Test Accuracy", total_accuracy_percentage, i)

    # Advance the LR schedule once per epoch, then checkpoint the weights.
    scheduler.step()
    torch.save(minivisionv3.state_dict(), f"{save_folder}/MiniVisionV3_Epoch{i}.pth")

# Flush pending events and close the TensorBoard writer (was missing: without
# this, the final events may never be written to disk).
writer.close()