File size: 3,339 Bytes
68ea1b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import sys
import torch
import torchvision
from model import MiniVisionV3
from torch.utils.data import DataLoader
from torchvision.transforms import functional as F
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


# Global config
#   epoch        -- number of passes over the training set
#   learningrate -- initial SGD learning rate (later decayed by the StepLR scheduler)
#   batchsize    -- mini-batch size shared by the train and test DataLoaders
epoch, learningrate, batchsize = 50, 5e-2, 256


# Per-epoch checkpoints are written into save_folder; TensorBoard event files
# are written to ../MiniVisionV3_log.
save_folder = "MiniVisionV3"
writer = SummaryWriter("../MiniVisionV3_log")

# exist_ok=True removes the check-then-create race of os.path.exists + os.mkdir
# and also creates any missing parent directories should save_folder be nested.
os.makedirs(save_folder, exist_ok=True)

# Orientation fix shared by both pipelines: a -90 degree rotation followed by a
# horizontal flip is equivalent to transposing the image (presumably to undo the
# transposed storage of EMNIST glyphs -- NOTE(review): confirm against the dataset).
_orientation_fix = [
    torchvision.transforms.Lambda(lambda img: F.rotate(img, -90)),
    F.hflip,
]

# Training pipeline: orientation fix, light augmentation (random 28x28 crop
# with 2px padding, +/-10 degree rotation), then conversion to a tensor.
transform_correct_train = torchvision.transforms.Compose(
    _orientation_fix
    + [
        torchvision.transforms.RandomCrop(28, 2),
        torchvision.transforms.RandomRotation(10),
        torchvision.transforms.ToTensor(),
    ]
)

# Evaluation pipeline: orientation fix only, no augmentation.
transform_correct_test = torchvision.transforms.Compose(
    _orientation_fix + [torchvision.transforms.ToTensor()]
)

# EMNIST "balanced" split for training and evaluation, downloaded on first use.
# NOTE(review): train and test use different root directories, so the EMNIST
# archive is presumably downloaded and extracted twice; pointing both at a
# single root would avoid that -- confirm before changing, existing data lives here.
dataset_train = torchvision.datasets.EMNIST(
    root="../EMNIST_train",
    split="balanced",
    train=True,
    download=True,
    transform=transform_correct_train,
)
dataset_test = torchvision.datasets.EMNIST(
    root="../EMNIST_test",
    split="balanced",
    train=False,
    download=True,
    transform=transform_correct_test,
)

# Only the training loader shuffles; evaluation order is fixed.
dataloader_train = DataLoader(dataset_train, batch_size=batchsize, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=batchsize, shuffle=False)


# Model, objective and optimisation schedule.
minivisionv3 = MiniVisionV3()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(minivisionv3.parameters(), lr=learningrate, momentum=0.8)
# Multiply the learning rate by 0.5 every 5 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Dataset sizes, reported once and reused to average per-sample metrics.
train_datasize, test_datasize = len(dataset_train), len(dataset_test)
print(f"Train dataset size: {train_datasize}")
print(f"test dataset size: {test_datasize}")

def _train_one_epoch(model, loader, criterion, opt):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Each batch's loss is weighted by its batch size before summing, and the sum
    is divided by ``len(loader.dataset)``. A smaller final batch is therefore
    weighted correctly, assuming ``criterion`` uses mean reduction, as
    ``torch.nn.CrossEntropyLoss()`` does by default.
    """
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(loader, file=sys.stdout):
        opt.zero_grad()
        output = model(imgs)
        loss = criterion(output, labels)
        loss.backward()
        opt.step()
        running_loss += loss.item() * len(imgs)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-sample loss, number of samples whose argmax over dim 1 equals
        the label). Assumes the model output is (batch, num_classes) logits --
        TODO confirm against MiniVisionV3.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader, file=sys.stdout):
            output = model(imgs)
            running_loss += criterion(output, labels).item() * len(imgs)
            correct += (output.argmax(1) == labels).sum().item()
    return running_loss / len(loader.dataset), correct


# Main loop: train, evaluate, log to TensorBoard, decay LR and checkpoint every epoch.
# try/finally guarantees the SummaryWriter is flushed and closed even if training
# is interrupted, so the already-written epochs are not lost from the logs.
try:
    for i in range(epoch):
        print(f"=============== Epoch {i} Start | LR: {optimizer.param_groups[0]['lr']} ===============")

        epoch_train_loss = _train_one_epoch(minivisionv3, dataloader_train, loss_fn, optimizer)
        print(f"Train epoch loss: {epoch_train_loss:.2f}")
        writer.add_scalar("Train Loss", epoch_train_loss, i)

        epoch_test_loss, total_accuracy = _evaluate(minivisionv3, dataloader_test, loss_fn)
        # Round only for the console; TensorBoard receives full precision
        # (the original logged the 2-decimal rounded value).
        print(f"Test epoch loss: {round(epoch_test_loss, 2)}")
        writer.add_scalar("Test Loss", epoch_test_loss, i)

        total_accuracy_percentage = (total_accuracy / test_datasize) * 100
        print(f"Test accuracy percentage: {round(total_accuracy_percentage, 2)}%")
        writer.add_scalar("Test Accuracy", total_accuracy_percentage, i)

        # Stepped once per epoch, after the optimizer updates of that epoch.
        scheduler.step()
        torch.save(minivisionv3.state_dict(), f"{save_folder}/MiniVisionV3_Epoch{i}.pth")
finally:
    writer.close()