# Snapshot metadata from the file viewer: 4,274 bytes, revision 3c79a02.
# Import the PyTorch and torchvision modules used by the helpers below.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
def isCUDAAvailable():
    """Report whether PyTorch can see a usable CUDA device.

    Returns:
        bool: The value of ``torch.cuda.is_available()``.
    """
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
def getDevice():
    """Pick the torch device used by the helpers in this module.

    Returns:
        torch.device: ``cuda`` when ``isCUDAAvailable()`` is true,
        otherwise ``cpu``.
    """
    if isCUDAAvailable():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
def plotData(loader, count, cmap_code):
    """Draw the first ``count`` samples of one batch from ``loader`` in a grid.

    A single batch is pulled with ``next(iter(loader))``; each selected sample
    is shown with ``imshow`` and titled with its label value.

    Args:
        loader: Iterable yielding ``(batch_data, batch_label)`` pairs.
            Assumes each sample is a ``(1, H, W)`` tensor so that
            ``squeeze(0)`` leaves a 2-D image -- TODO confirm against caller.
        count: Number of samples to draw. Values larger than the batch are
            clamped to the batch length instead of raising ``IndexError``.
        cmap_code: Colormap forwarded to ``plt.imshow`` (e.g. ``'gray'``).

    Returns:
        None. The figure is created on the current matplotlib backend.
    """
    import matplotlib.pyplot as plt

    batch_data, batch_label = next(iter(loader))

    # Never index past the end of the batch that was actually drawn.
    count = min(count, len(batch_data))

    cols = 4
    # Keep the historical 3x4 layout for up to 12 images and only add rows
    # when more are requested; a fixed 3x4 grid made subplot() raise for
    # count > 12.
    rows = max(3, (count + cols - 1) // cols)

    plt.figure()
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(batch_data[i].squeeze(0), cmap=cmap_code)
        plt.title(batch_label[i].item())
        plt.xticks([])
        plt.yticks([])

    # One layout pass after all axes (and their titles) exist, instead of
    # re-running it on every iteration.
    if count > 0:
        plt.tight_layout()
def getTrainTransforms_CropRotate(centerCrop, resize, randomRotate, mean, std_dev):
    """Build the augmenting training pipeline: crop, resize, rotate, normalize.

    Args:
        centerCrop: Size passed to ``transforms.CenterCrop``; the crop is
            wrapped in ``RandomApply`` with ``p=0.1``.
        resize: Side length; images are resized to ``(resize, resize)``.
        randomRotate: Rotation bound; the angle range is
            ``(-randomRotate, randomRotate)`` with ``fill=0``.
        mean: Value wrapped as the one-element mean tuple for ``Normalize``.
        std_dev: Value wrapped as the one-element std tuple for ``Normalize``.

    Returns:
        transforms.Compose: The composed training transform.
    """
    occasional_crop = transforms.RandomApply(
        [transforms.CenterCrop(centerCrop)], p=0.1
    )
    square_resize = transforms.Resize((resize, resize))
    rotation = transforms.RandomRotation((-randomRotate, randomRotate), fill=0)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((mean,), (std_dev,))

    pipeline = [occasional_crop, square_resize, rotation, to_tensor, normalize]
    return transforms.Compose(pipeline)
def getTrainTransforms(mean, std_dev):
    """Build the plain training pipeline: tensor conversion then normalization.

    Args:
        mean: Value wrapped as the one-element mean tuple for ``Normalize``.
        std_dev: Value wrapped as the one-element std tuple for ``Normalize``.

    Returns:
        transforms.Compose: ``ToTensor`` followed by ``Normalize``.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((mean,), (std_dev,))
    return transforms.Compose([to_tensor, normalize])
def getTestTransforms(mean, std_dev):
    """Build the evaluation pipeline: tensor conversion then normalization.

    Args:
        mean: Value wrapped as the one-element mean tuple for ``Normalize``.
        std_dev: Value wrapped as the one-element std tuple for ``Normalize``.

    Returns:
        transforms.Compose: ``ToTensor`` followed by ``Normalize``.
    """
    steps = []
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((mean,), (std_dev,)))
    return transforms.Compose(steps)
def GetCorrectPredCount(pPrediction, pLabels):
    """Count how many rows of ``pPrediction`` have their argmax equal to the label.

    Args:
        pPrediction: Tensor of per-class scores; the predicted class is the
            argmax along ``dim=1``.
        pLabels: Tensor of target class indices compared element-wise with
            the predicted classes.

    Returns:
        int: Number of matching predictions.
    """
    predicted_classes = pPrediction.argmax(dim=1)
    matches = predicted_classes == pLabels
    return matches.sum().item()
def train(model, train_loader, optimizer, criterion, scheduler):
    """Run one pass over ``train_loader``, updating ``model`` after each batch.

    Each batch is moved to the device returned by ``getDevice()``, a forward
    and backward pass is run, and ``optimizer.step()`` is applied. Progress
    is shown on a tqdm bar.

    NOTE(review): the indentation of this file was lost, so whether
    ``scheduler.step()`` belongs inside the batch loop is not recoverable
    from the text. It is placed after the loop here (one step per call);
    move it into the loop if a per-batch scheduler is used -- confirm
    against the caller.

    Args:
        model: Module to train; switched to training mode via ``model.train()``.
        train_loader: Iterable of ``(data, target)`` batches; must support ``len``.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        criterion: Callable ``criterion(pred, target)`` returning the loss tensor.
        scheduler: Learning-rate scheduler; ``step()`` is called once.

    Returns:
        tuple[list, list]: ``(train_acc, train_losses)``, each a one-element
        list holding the accuracy percentage over all processed samples and
        the summed batch loss divided by ``len(train_loader)``.
    """
    from tqdm import tqdm

    device = getDevice()
    model.train()
    progress = tqdm(train_loader)

    running_loss = 0
    n_correct = 0
    n_seen = 0

    for step, (inputs, labels) in enumerate(progress):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        running_loss += batch_loss.item()

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        n_correct += GetCorrectPredCount(outputs, labels)
        n_seen += len(inputs)

        accuracy_so_far = 100*n_correct/n_seen
        progress.set_description(desc= f'Train: Loss={batch_loss.item():0.4f} Batch_id={step} Accuracy={accuracy_so_far:0.2f}')

    scheduler.step()

    epoch_accuracy = [100*n_correct/n_seen]
    epoch_loss = [running_loss/len(train_loader)]
    return epoch_accuracy, epoch_loss
def test(model, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    The model is switched to eval mode and run under ``torch.no_grad()``;
    each batch is moved to the device returned by ``getDevice()``.

    The reported loss is a per-sample average. Previously the per-batch loss
    values were summed and divided by the dataset size, which under-reports
    the loss by roughly the batch size when ``criterion`` returns a batch
    mean (the PyTorch default, and what ``train`` assumes when it divides by
    ``len(train_loader)``). Batch losses are now weighted by batch size
    unless the criterion declares ``reduction='sum'``.

    NOTE(review): the reduction mode is read from ``criterion.reduction``
    (``nn.Module`` losses) or ``criterion.keywords`` (``functools.partial``).
    A plain function or lambda that sums internally cannot be detected and
    is treated as mean-reduced -- confirm against the caller.

    Args:
        model: Module to evaluate.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute that supports ``len``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.

    Returns:
        tuple[list, list]: ``(test_acc, test_losses)``, each a one-element
        list holding the accuracy percentage and the per-sample average loss.
    """
    model.eval()
    device = getDevice()
    dataset_size = len(test_loader.dataset)

    # Work out whether each criterion call yields a batch total or a batch mean.
    reduction = getattr(criterion, "reduction", None)
    if reduction is None:
        partial_kwargs = getattr(criterion, "keywords", None) or {}
        reduction = partial_kwargs.get("reduction", "mean")
    loss_is_batch_total = reduction == "sum"

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_loss = criterion(output, target).item()
            if not loss_is_batch_total:
                # Turn the batch mean back into a batch total so the final
                # division by the dataset size gives a per-sample average.
                batch_loss *= len(data)
            test_loss += batch_loss
            correct += GetCorrectPredCount(output, target)

    test_loss /= dataset_size
    accuracy = 100. * correct / dataset_size

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, dataset_size, accuracy))

    return [accuracy], [test_loss]
def printModelSummary(model, inputSize):
    """Hand ``model`` to ``torchsummary.summary`` for reporting.

    Args:
        model: Model to summarize.
        inputSize: Forwarded as the ``input_size`` keyword of ``summary``.

    Returns:
        None.
    """
    from torchsummary import summary as render_summary

    render_summary(model, input_size=inputSize)
def printModelTrainTestAccuracy(train_acc, train_losses, test_acc, test_losses):
    """Plot training/test loss and accuracy curves on a 2x2 grid of axes.

    Layout: training loss (top-left), test loss (top-right), training
    accuracy (bottom-left), test accuracy (bottom-right).

    Args:
        train_acc: Sequence plotted as "Training Accuracy".
        train_losses: Sequence plotted as "Training Loss".
        test_acc: Sequence plotted as "Test Accuracy".
        test_losses: Sequence plotted as "Test Loss".

    Returns:
        None.
    """
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (grid position, data series, panel title), in drawing order.
    panels = (
        ((0, 0), train_losses, "Training Loss"),
        ((1, 0), train_acc, "Training Accuracy"),
        ((0, 1), test_losses, "Test Loss"),
        ((1, 1), test_acc, "Test Accuracy"),
    )
    for position, series, heading in panels:
        axis = axes[position]
        axis.plot(series)
        axis.set_title(heading)
|