File size: 4,274 Bytes
3c79a02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#Import all modules needed for Pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

def isCUDAAvailable():
    """Report whether CUDA is available to PyTorch (``torch.cuda.is_available()``)."""
    cuda_ok = torch.cuda.is_available()
    return cuda_ok

def getDevice():
    """Return the device to run on: ``cuda`` when CUDA is available, else ``cpu``."""
    if isCUDAAvailable():
        return torch.device("cuda")
    return torch.device("cpu")
    
def plotData(loader, count, cmap_code):
    """Plot the first ``count`` images of the first batch drawn from ``loader``.

    Each image goes in its own subplot, titled with its label, with axis
    ticks hidden.

    Args:
        loader: Iterable yielding ``(batch_data, batch_label)`` pairs; only
            the first pair is used.
        count: Number of images to plot. Values larger than the batch size
            are clamped to the batch size.
        cmap_code: Matplotlib colormap passed to ``imshow`` as ``cmap``.
    """
    import math
    import matplotlib.pyplot as plt

    batch_data, batch_label = next(iter(loader))
    # Never index past the end of the batch.
    count = min(count, len(batch_data))

    # The grid used to be fixed at 3x4, so count > 12 made plt.subplot raise.
    # Keep 4 columns and at least 3 rows (the original layout), adding rows
    # when more images are requested.
    ncols = 4
    nrows = max(3, math.ceil(count / ncols))

    plt.figure()
    for i in range(count):
        plt.subplot(nrows, ncols, i + 1)
        # squeeze(0) drops a leading singleton dimension (e.g. a 1-channel
        # axis) so imshow gets a 2-D array -- assumes CHW-style tensors; confirm.
        plt.imshow(batch_data[i].squeeze(0), cmap=cmap_code)
        plt.title(batch_label[i].item())
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after all subplots exist.
    plt.tight_layout()
        
def getTrainTransforms_CropRotate(centerCrop, resize, randomRotate, mean, std_dev):
    """Build the augmented training transform pipeline.

    Steps, in order:
      1. with probability 0.1, center-crop to ``centerCrop``;
      2. resize to ``(resize, resize)``;
      3. rotate randomly within ``[-randomRotate, randomRotate]`` degrees,
         filling uncovered pixels with 0;
      4. convert to a tensor;
      5. normalize with ``mean`` / ``std_dev``.

    Args:
        centerCrop: Size passed to ``transforms.CenterCrop``.
        resize: Side length of the square output of ``transforms.Resize``.
        randomRotate: Maximum absolute rotation angle in degrees.
        mean: Normalization mean. A scalar is used for a single channel (as
            before); a list/tuple supplies one value per channel.
        std_dev: Normalization std, same form as ``mean``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    # Normalize expects one value per channel. Wrap scalars (original
    # behaviour) but pass sequences through, so multi-channel stats work too.
    mean_seq = tuple(mean) if isinstance(mean, (list, tuple)) else (mean,)
    std_seq = tuple(std_dev) if isinstance(std_dev, (list, tuple)) else (std_dev,)

    train_transforms = transforms.Compose([
        transforms.RandomApply([transforms.CenterCrop(centerCrop), ], p=0.1),
        transforms.Resize((resize, resize)),
        transforms.RandomRotation((-randomRotate, randomRotate), fill=0),
        transforms.ToTensor(),
        transforms.Normalize(mean_seq, std_seq),
    ])
    return train_transforms

def getTrainTransforms(mean, std_dev):
    """Build the plain (non-augmented) training transform pipeline.

    Converts the input to a tensor, then normalizes it.

    Args:
        mean: Normalization mean. A scalar is used for a single channel (as
            before); a list/tuple supplies one value per channel.
        std_dev: Normalization std, same form as ``mean``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    # Normalize expects one value per channel. Wrap scalars (original
    # behaviour) but pass sequences through, so multi-channel stats work too.
    mean_seq = tuple(mean) if isinstance(mean, (list, tuple)) else (mean,)
    std_seq = tuple(std_dev) if isinstance(std_dev, (list, tuple)) else (std_dev,)

    train_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean_seq, std_seq),
    ])
    return train_transforms

def getTestTransforms(mean, std_dev):
    """Build the test/evaluation transform pipeline.

    Converts the input to a tensor, then normalizes it (no augmentation).

    Args:
        mean: Normalization mean. A scalar is used for a single channel (as
            before); a list/tuple supplies one value per channel.
        std_dev: Normalization std, same form as ``mean``.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    # Normalize expects one value per channel. Wrap scalars (original
    # behaviour) but pass sequences through, so multi-channel stats work too.
    mean_seq = tuple(mean) if isinstance(mean, (list, tuple)) else (mean,)
    std_seq = tuple(std_dev) if isinstance(std_dev, (list, tuple)) else (std_dev,)

    test_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean_seq, std_seq),
    ])
    return test_transforms


def GetCorrectPredCount(pPrediction, pLabels):
    """Count predictions whose arg-max along dim 1 equals the matching label.

    Args:
        pPrediction: Score tensor; the predicted class is the index of the
            maximum along dim 1.
        pLabels: Tensor of target class indices, compared element-wise with
            the predicted indices.

    Returns:
        The number of matches, as a Python number (via ``.item()``).
    """
    predicted_classes = torch.argmax(pPrediction, dim=1)
    matches = predicted_classes == pLabels
    return torch.sum(matches).item()

def train(model, train_loader, optimizer, criterion, scheduler=None):
    """Run one training epoch over ``train_loader``.

    For each batch this function:
      1. moves ``data``/``target`` to ``getDevice()``;
      2. runs the forward pass and computes ``criterion`` loss;
      3. backpropagates and takes an optimizer step;
      4. steps ``scheduler``, if one is given.
    A tqdm progress bar shows the latest batch loss and the running accuracy.

    Args:
        model: Network to train; switched to training mode here.
            NOTE(review): assumed to already be on ``getDevice()``; only the
            batches are moved here.
        train_loader: Iterable of ``(data, target)`` batches; ``len()`` of it
            is used to average the loss.
        optimizer: Optimizer; ``zero_grad``/``step`` are called once per batch.
        criterion: Loss callable ``criterion(pred, target)`` returning a
            one-element tensor (``.item()`` is called on it).
        scheduler: Optional LR scheduler, stepped once per *batch* (suits
            per-iteration schedules such as OneCycleLR). ``None`` skips
            scheduling.

    Returns:
        ``(train_acc, train_losses)``: single-element lists holding the epoch
        accuracy in percent and the mean of the per-batch loss values.
    """
    from tqdm import tqdm

    model.train()
    device = getDevice()
    pbar = tqdm(train_loader)

    train_loss = 0
    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Predict
        pred = model(data)

        # Calculate loss (read the scalar once; reused for the progress bar)
        loss = criterion(pred, target)
        batch_loss = loss.item()
        train_loss += batch_loss

        # Backpropagation
        loss.backward()
        optimizer.step()

        correct += GetCorrectPredCount(pred, target)
        processed += len(data)

        pbar.set_description(desc=f'Train: Loss={batch_loss:0.4f} Batch_id={batch_idx} Accuracy={100*correct/processed:0.2f}')

        # Stepped per batch, not per epoch.
        if scheduler is not None:
            scheduler.step()

    train_acc = [100 * correct / processed]
    train_losses = [train_loss / len(train_loader)]
    return train_acc, train_losses

def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print a one-line summary.

    Runs without gradient tracking, with the model in eval mode. Batches are
    moved to ``getDevice()``.

    Args:
        model: Network to evaluate. NOTE(review): assumed to already be on
            ``getDevice()``.
        test_loader: Loader of ``(data, target)`` batches that exposes
            ``.dataset`` (its length is the number of samples).
        criterion: Loss callable ``criterion(output, target)`` returning a
            one-element tensor. If it has ``reduction == 'sum'``, each call is
            treated as a batch total. Otherwise it is treated as a batch mean
            (PyTorch's default).

    Returns:
        ``(test_acc, test_losses)``: single-element lists holding accuracy in
        percent and the average per-sample loss.
    """
    model.eval()
    device = getDevice()

    # Previously the per-batch (mean-reduced) losses were summed and divided
    # by the dataset size, under-reporting the loss by ~the batch size. Scale
    # mean-reduced batch losses back to totals before averaging per sample.
    # NOTE: plain callables without a `reduction` attribute (e.g. F.nll_loss)
    # are treated as mean-reduced, matching their default.
    sums_per_batch = getattr(criterion, 'reduction', 'mean') == 'sum'

    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            output = model(data)
            batch_loss = criterion(output, target).item()
            test_loss += batch_loss if sums_per_batch else batch_loss * len(data)

            correct += GetCorrectPredCount(output, target)

    dataset_size = len(test_loader.dataset)
    test_loss /= dataset_size
    accuracy = 100. * correct / dataset_size

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, dataset_size, accuracy))
    return [accuracy], [test_loss]


def printModelSummary(model, inputSize):
    """Print a torchsummary layer-by-layer summary of ``model``.

    Args:
        model: Network to summarize.
        inputSize: Input shape forwarded to ``torchsummary.summary`` as
            ``input_size``.
    """
    import torchsummary
    torchsummary.summary(model, input_size=inputSize)
    
def printModelTrainTestAccuracy(train_acc, train_losses, test_acc, test_losses):
    """Plot training/test loss and accuracy curves in a 2x2 grid.

    The left column shows training (loss on top, accuracy below) and the right
    column shows test (same arrangement). The figure is 15x10 inches.
    """
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(2, 2, figsize=(15, 10))
    # (grid position, series, title) in the original drawing order.
    panels = (
        ((0, 0), train_losses, "Training Loss"),
        ((1, 0), train_acc, "Training Accuracy"),
        ((0, 1), test_losses, "Test Loss"),
        ((1, 1), test_acc, "Test Accuracy"),
    )
    for (row, col), series, title in panels:
        ax = axes[row, col]
        ax.plot(series)
        ax.set_title(title)