File size: 3,976 Bytes
c3f9ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Local module; the wildcard is kept because it is what brings MyNetwork into scope.
from model import *

# dir configs
# Directory that receives the per-epoch checkpoints.
save_dir = "mini-vision"
# makedirs(exist_ok=True) creates the directory only when missing, without the
# check-then-create race of os.path.exists() + os.mkdir(); it also creates
# intermediate directories should save_dir ever become a nested path.
os.makedirs(save_dir, exist_ok=True)

# TensorBoard writer; event files are written under "mini-vision-logs"
writer = SummaryWriter(log_dir="mini-vision-logs")

# compute device (CUDA when available, CPU otherwise) and hyperparameters
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
batchsize = 256
learning_rate = 7e-3

# training-time augmentation: random 32-pixel crop with 4 pixels of padding,
# random horizontal flip, then conversion to a tensor
train_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop(size=32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
    ]
)

# CIFAR-10 train / test splits, downloaded into ./CIFAR10 when not present;
# the test split is only converted to tensors (no augmentation)
train_data = torchvision.datasets.CIFAR10(
    root="CIFAR10",
    train=True,
    transform=train_transforms,
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root="CIFAR10",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# number of samples in each split, echoed one per line
train_data_size = len(train_data)
test_data_size = len(test_data)
print(train_data_size, test_data_size, sep="\n")

# batch loaders: only the training split is shuffled
train_dataloader = DataLoader(train_data, batch_size=batchsize, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batchsize, shuffle=False)

# network and loss, both placed on the selected device
mynetwork = MyNetwork().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

# SGD with momentum 0.9; StepLR multiplies the learning rate by 0.5 every
# 5 scheduler steps
optimizer = torch.optim.SGD(mynetwork.parameters(), lr=learning_rate, momentum=0.9)
schedular = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# running counters used as TensorBoard step indices (train / test)
total_train_step = total_test_step = 0
# number of epochs to train for
epoch = 100



# Main loop: for each epoch run one training pass over train_dataloader, one
# evaluation pass over test_dataloader, log metrics to TensorBoard, step the
# LR scheduler and write a checkpoint. The try/finally guarantees the
# TensorBoard writer is closed even if an epoch fails or is interrupted.
try:
    for i in range(epoch):
        # Read the LR into a local first: reusing double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12.
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"---------------Epoch {i + 1} start, LR:{current_lr}---------------")

        # ---- training pass ----
        mynetwork.train()
        total_train_loss = 0
        print("Training Progress: ", flush=True)
        for imgs, targets in tqdm(train_dataloader, file=sys.stdout):
            imgs = imgs.to(device)
            targets = targets.to(device)

            output = mynetwork(imgs)
            loss = loss_fn(output, targets)

            # optim model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_train_loss += batch_loss
            total_train_step += 1
            # The counter was just incremented, so it already is the 1-based
            # index of this step (same convention as the test metrics below).
            writer.add_scalar("train_loss", batch_loss, total_train_step)

        # Average over the real number of batches; dataset_size / batchsize is
        # fractional whenever the last batch is partial and skews the mean.
        train_loss_num = max(len(train_dataloader), 1)
        total_train_loss /= train_loss_num
        print(f"Total avg loss on train data: {total_train_loss:.2f}", flush=True)

        # ---- evaluation pass ----
        mynetwork.eval()
        total_test_loss = 0
        total_accuracy = 0
        with torch.no_grad():
            print("Testing Progress", flush=True)
            for imgs, targets in tqdm(test_dataloader, file=sys.stdout):
                imgs = imgs.to(device)
                targets = targets.to(device)

                output = mynetwork(imgs)
                loss = loss_fn(output, targets)
                total_test_loss += loss.item()
                # Count correct predictions as a plain int so the percentage
                # below is computed in Python floats, not float32 tensors.
                accuracy = (output.argmax(1) == targets).sum().item()
                total_accuracy += accuracy

        accuracy_percentage = round(total_accuracy / test_data_size * 100, 2)
        test_loss_num = max(len(test_dataloader), 1)
        total_test_loss /= test_loss_num
        print(f"Total avg loss on test data: {total_test_loss:.2f}", flush=True)
        print(f"Accuracy on test data: {accuracy_percentage}%", flush=True)
        total_test_step += 1
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        writer.add_scalar("test_accuracy", accuracy_percentage, total_test_step)

        schedular.step()
        # save every epoch
        # NOTE(review): this saves the whole module object rather than a
        # state_dict, so loading presumably needs the model class importable
        # under the same path. The epoch number is appended directly after
        # "V1" in the file name (epoch 1 -> "...V11.pth") -- confirm intended.
        torch.save(mynetwork, f"{save_dir}/Mini-Vision-V1{i + 1}.pth")
        print("Model saved", flush=True)
finally:
    writer.close()