File size: 3,427 Bytes
5b6d90c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import torch
import sys
from torch import nn
import torchvision
from datasets import load_dataset
from torch.utils.data import DataLoader
from model import MiniVisionV2
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


# Hyperparameters and output locations for the MNIST training run.
save_path = "minivisionv2_model"
batchsize = 256
learningrate = 1e-2
epoch = 50

# Create the checkpoint directory on first run.
if not os.path.exists(save_path):
    os.mkdir(save_path)

# TensorBoard logger; scalars go under ./minivisionv2_logs.
writer = SummaryWriter("minivisionv2_logs")

# Loads (and on first run downloads) MNIST: "train"/"test" splits of
# PIL images plus integer labels.
dataset = load_dataset("ylecun/mnist")

# Training-time augmentation: random 28x28 crop with 2px padding, up to
# +/-10 degree rotation, then PIL -> float tensor in [0, 1].
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(28, 2),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.ToTensor()
])
# Test-time pipeline: tensor conversion only, no augmentation.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

def transforms_train(data):
    """Add a "tensor" column with training-augmented versions of "image".

    Existing columns (including "image") are left in place.
    """
    data["tensor"] = list(map(transform_train, data["image"]))
    return data

def transforms_test(data):
    """Add a "tensor" column with tensor-converted versions of "image".

    Existing columns (including "image") are left in place.
    """
    data["tensor"] = list(map(transform_test, data["image"]))
    return data

# with_transform applies the mapping lazily on access, so the training
# augmentation is re-sampled every time an example is fetched.
train_dataset = dataset["train"].with_transform(transforms_train)
test_dataset = dataset["test"].with_transform(transforms_test)

def collate_fn(batch):
    """Collate a list of dataset examples into batched tensors.

    Args:
        batch: list of dicts, each with a "tensor" image and an integer
            "label".

    Returns:
        dict with "tensor" (images stacked along a new batch dim) and
        "label" (1-D integer tensor of labels).
    """
    images = [example["tensor"] for example in batch]
    labels = [example["label"] for example in batch]
    return {"tensor": torch.stack(images), "label": torch.tensor(labels)}

# Keyword arguments replace the bare positional booleans so the intent is
# readable: shuffle the training set each epoch, keep the test set in order.
train_loader = DataLoader(train_dataset, batch_size=batchsize, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batchsize, shuffle=False, collate_fn=collate_fn)


# Model, loss, and optimization setup. Keyword arguments replace the opaque
# positional constants (the originals relied on parameter order for
# momentum, step_size, and gamma).
minivisionv2 = MiniVisionV2()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(minivisionv2.parameters(), lr=learningrate, momentum=0.8)
# Halve the learning rate every 3 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)

for i in range(epoch):
    # BUGFIX: the original nested double quotes inside a double-quoted
    # f-string ({...[0]["lr"]}), which is a SyntaxError on Python < 3.12
    # (PEP 701 only relaxed this in 3.12). Hoisting the lookup avoids it.
    current_lr = optimizer.param_groups[0]["lr"]
    print(f"=============== Epoch {i} Start | LR: {current_lr} ===============")

    # ---- Training pass ----
    minivisionv2.train()
    total_train_loss = 0
    for data in tqdm(train_loader, file=sys.stdout):
        optimizer.zero_grad()
        imgs = data["tensor"]
        labels = data["label"]
        output = minivisionv2(imgs)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
    total_avg_train_loss = total_train_loss / len(train_loader)
    print(f"Train loss: {total_avg_train_loss}")
    writer.add_scalar("Train Loss", total_avg_train_loss, i)

    # ---- Evaluation pass (no gradients) ----
    minivisionv2.eval()
    with torch.no_grad():

        total_accuracy = 0
        total_test_loss = 0
        for data in tqdm(test_loader, file=sys.stdout):
            imgs = data["tensor"]
            labels = data["label"]
            output = minivisionv2(imgs)
            loss = loss_fn(output, labels)
            # BUGFIX: accumulate a Python float, not the loss tensor; the
            # original summed tensors and then printed/logged a tensor to
            # TensorBoard instead of a scalar.
            total_test_loss += loss.item()
            # Count of correct top-1 predictions in this batch.
            accuracy = (output.argmax(1) == labels).sum()
            total_accuracy += accuracy.item()

        total_avg_test_loss = total_test_loss / len(test_loader)
        total_accuracy_percentage = round(float(total_accuracy / len(test_dataset) * 100), 2)
        print(f"Test loss: {total_avg_test_loss}")
        print(f"Test Accuracy Percentage: {total_accuracy_percentage}%")
        writer.add_scalar("Test Loss", total_avg_test_loss, i)
        writer.add_scalar("Test Accuracy Percentage", total_accuracy_percentage, i)

    # NOTE(review): saving the whole module object ties checkpoints to the
    # MiniVisionV2 class definition/path; prefer minivisionv2.state_dict()
    # if the loading code can be updated. Kept as-is so existing loaders
    # of these .pth files keep working.
    torch.save(minivisionv2, f"./{save_path}/Mini-Vision-V2-Epoch-{i}.pth")
    print("Model Saved!")
    scheduler.step()

writer.close()