File size: 2,211 Bytes
2dfdcd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import time
import torch as tr
import torchvision as tv
import torchvision.transforms as transforms
from model import model, loss_fn, optimizer

def _train_one_epoch(loader, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the ``model``, ``loss_fn`` and ``optimizer`` objects imported at
    module level.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch loss values as a Python float,
        or ``nan`` if the loader yielded no batches.
    """
    model.train()
    loss_sum = 0.0
    batch_count = 0
    for images, labels in loader:
        # non_blocking asks for an asynchronous copy when that is possible
        # (e.g. from pinned host memory to a CUDA device).
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = model(images)
        # NOTE(review): taking log() of the clamped output presumably means
        # the model emits probabilities and loss_fn expects log-probabilities
        # -- confirm against model.py. The clamp keeps log() away from zero.
        loss = loss_fn(tr.log(tr.clamp(output, min=1e-9)), labels)
        loss.backward()
        optimizer.step()
        # Keep the running sum as a detached tensor; it is converted to a
        # Python float only once, after the loop.
        loss_sum = loss_sum + loss.detach()
        batch_count += 1
    if batch_count == 0:
        return float("nan")
    return float(loss_sum) / batch_count


def _evaluate(loader, device):
    """Return the top-1 accuracy of the module-level ``model`` over ``loader``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose arg-max prediction equals the label, or
        ``nan`` if the loader yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with tr.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            preds = tr.argmax(model(images), dim=1)
            correct += tr.sum(preds == labels).item()
            total += labels.size(0)
    if total == 0:
        return float("nan")
    return correct / total


def main(batch_size: int = 128,
         epochs: int = 10,
         data_root: str = 'data/tiny-imagenet-200',
         num_workers: int = 8) -> None:
    """Train the imported ``model`` and print loss/accuracy after every epoch.

    Builds ``ImageFolder`` datasets from ``<data_root>/train`` and
    ``<data_root>/val``, trains for ``epochs`` passes, evaluates on the
    validation loader after each pass, and finally prints the last loss,
    the last accuracy and the total wall-clock time.

    The reported loss is the mean of the per-batch training losses of the
    epoch (not just the value of the last mini-batch).

    Args:
        batch_size: Number of samples per batch for both loaders.
        epochs: Number of full passes over the training set; must be >= 1.
        data_root: Directory containing the ``train`` and ``val`` folders.
        num_workers: Worker count passed to both ``DataLoader`` objects.

    Raises:
        ValueError: If ``epochs`` is less than 1, since there would be no
            metrics to report.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    device = tr.device("cuda" if tr.cuda.is_available() else "cpu")
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory = device.type == "cuda"

    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor()])

    trn_dataset = tv.datasets.ImageFolder(
        root=f'{data_root}/train',
        transform=transform)

    # NOTE(review): ImageFolder derives class indices from sub-directory
    # names. Confirm that val/ is organised into the same per-class folders
    # as train/; otherwise the labels compared in _evaluate will not match.
    evl_dataset = tv.datasets.ImageFolder(
        root=f'{data_root}/val',
        transform=transform)

    trn_loader = tr.utils.data.DataLoader(
        trn_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory)

    evl_loader = tr.utils.data.DataLoader(
        evl_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory)

    model.to(device)
    print("Running on:", device)

    start_time = time.time()
    for epoch in range(1, epochs + 1):
        epoch_loss = _train_one_epoch(trn_loader, device)
        accuracy = _evaluate(evl_loader, device)
        print(f"Epoch {epoch}/{epochs} - Loss: {epoch_loss:.4f} - Accuracy: {accuracy*100:.2f}%")
    end_time = time.time()

    print("Training finished")
    print(f"Final Loss: {epoch_loss:.4f}")
    print(f"Final Accuracy: {accuracy*100:.2f}%")
    print(f"Total Time: {end_time - start_time:.2f} seconds")


# Entry-point guard: start training only when this file is executed as a
# script, not when it is imported. main() builds DataLoaders that use worker
# processes, and on platforms where workers are spawned (rather than forked)
# the module may be re-imported in each worker -- the guard keeps those
# imports from launching training again.
if __name__ == "__main__":
    main()