Vvaann committed on
Commit
6254023
·
verified ·
1 Parent(s): 7d3a34b

Upload 4 files

Browse files
Files changed (4) hide show
  1. epoch=19-step=3920.ckpt +3 -0
  2. resnet_lightning.py +179 -0
  3. utils.py +298 -0
  4. visualize.py +384 -0
epoch=19-step=3920.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b3c28bc1f895ab927922a707c81ea52aac3bf6c311ba91577f706e997a66f73
3
+ size 89490895
resnet_lightning.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import all the required modules
2
+ import os
3
+ os.environ['KMP_DUPLICATE_LIB_OK']='True'
4
+ import math
5
+ from collections import OrderedDict
6
+ import sys
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torchvision import datasets
12
+
13
+
14
+ import albumentations as A
15
+ from albumentations.pytorch import ToTensorV2
16
+
17
+
18
+ from torch_lr_finder import LRFinder
19
+
20
+ from pytorch_grad_cam import GradCAM
21
+ from utils import *
22
+
23
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34.

    Applies conv3x3 -> BN -> ReLU -> conv3x3 -> BN, then adds a (possibly
    projected) skip connection before the final ReLU.
    """

    # Channel multiplier of the block's output; BasicBlock keeps it at 1.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Project the input when the spatial size or channel count changes;
        # otherwise the skip connection is the identity (empty Sequential).
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return the block output for an (N, C, H, W) input tensor."""
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch = branch + self.shortcut(x)
        return F.relu(branch)
46
+
47
+
48
class ResNet18Model(LightningModule):
    """ResNet-18 LightningModule for CIFAR-10.

    Bundles the network definition, the albumentations train/test
    transforms, the CIFAR-10 data pipeline and the optimizer setup.

    :param data_dir: Root directory where CIFAR-10 is stored/downloaded
    :param block: Residual block class (defaults to BasicBlock)
    :param num_blocks: Number of residual blocks in each of the 4 stages
    :param num_classes: Number of output classes
    """

    def __init__(self, data_dir=PATH_DATASETS, block=BasicBlock, num_blocks=[2, 2, 2, 2], num_classes=10):
        super(ResNet18Model, self).__init__()
        self.data_dir = data_dir
        self.num_classes = num_classes

        # CIFAR-10 per-channel statistics used for normalization and as the
        # fill value of the CoarseDropout cutout patch.
        means = [0.4914, 0.4822, 0.4465]
        stds = [0.2470, 0.2435, 0.2616]

        # Training pipeline: normalize, pad to 36x36 then random-crop back to
        # 32x32 (random translation), random flip, one cutout hole, to-tensor.
        self.train_transforms = A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
                A.RandomCrop(height=32, width=32, always_apply=True),
                A.HorizontalFlip(),
                A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=8, min_width=8, fill_value=means),
                ToTensorV2(),
            ]
        )
        # Evaluation pipeline: normalization only, no augmentation.
        self.test_transforms = A.Compose(
            [
                A.Normalize(mean=means, std=stds, always_apply=True),
                ToTensorV2(),
            ]
        )

        self.in_planes = 64

        # Network definition (ResNet-18 layout when num_blocks=[2, 2, 2, 2]).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # FIX: use the configured class count instead of a hard-coded 10 so
        # the metric stays consistent with the classifier head.
        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first may downsample."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalized) class logits for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # BUGFIX: forward() returns raw logits (no log_softmax), so
        # cross_entropy is the correct loss here; the previous F.nll_loss
        # expects log-probabilities and reported a wrong validation loss.
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Plain SGD with momentum and weight decay.

        The dead, commented-out LR-finder/OneCycleLR experiment was removed;
        reintroduce a scheduler here by returning it alongside the optimizer.
        """
        LEARNING_RATE = 0.03
        WEIGHT_DECAY = 1e-4
        return torch.optim.SGD(self.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=WEIGHT_DECAY)

    def prepare_data(self):
        # Download both splits once; setup() re-instantiates them with transforms.
        Cifar10SearchDataset(self.data_dir, train=True, download=True)
        Cifar10SearchDataset(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = Cifar10SearchDataset(self.data_dir, train=True, transform=self.train_transforms)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = Cifar10SearchDataset(self.data_dir, train=False, transform=self.test_transforms)

    def train_dataloader(self):
        # BUGFIX: the training split must be shuffled each epoch; DataLoader
        # defaults to shuffle=False.
        return DataLoader(self.cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=BATCH_SIZE, num_workers=os.cpu_count())
utils.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ from torch_lr_finder import LRFinder
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image
12
+
13
+
14
+ import albumentations as A
15
+ from albumentations.pytorch import ToTensorV2
16
+
17
# Train data transformations
# CIFAR-10 per-channel mean / std, used both to normalize inputs and as the
# fill value for the CoarseDropout cutout patch below.
means = [0.4914, 0.4822, 0.4465]
stds = [0.2470, 0.2435, 0.2616]

# Training pipeline: normalize, pad to 36x36 then random-crop back to 32x32
# (random translation), random horizontal flip, one 8-16px cutout hole,
# and finally convert the HWC numpy image to a CHW tensor.
train_transforms = A.Compose(
    [
        A.Normalize(mean=means, std=stds, always_apply=True),
        A.PadIfNeeded(min_height=36, min_width=36, always_apply=True),
        A.RandomCrop(height=32, width=32, always_apply=True),
        A.HorizontalFlip(),
        A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=8, min_width=8, fill_value=means),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: normalization only, no augmentation.
test_transforms = A.Compose(
    [
        A.Normalize(mean=means, std=stds, always_apply=True),
        ToTensorV2(),
    ]
)
38
+
39
+
40
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 wrapper that applies albumentations transforms.

    torchvision transforms are called as ``transform(img)``, whereas
    albumentations pipelines are called as ``transform(image=...)`` and
    return a dict — this subclass overrides ``__getitem__`` to bridge
    that difference.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        # self.data is an HWC uint8 numpy array and self.targets a list of
        # ints (both inherited from torchvision's CIFAR10).
        image, label = self.data[index], self.targets[index]
        if self.transform is not None:
            # albumentations keyword call; the result dict holds the image
            # under the "image" key.
            transformed = self.transform(image=image)
            image = transformed["image"]
        return image, label
51
+
52
def dataloader(data_path, batch_size):
    """Build CIFAR-10 train/test DataLoaders using the module-level transforms.

    :param data_path: Root directory for the CIFAR-10 download
    :param batch_size: Batch size used for both loaders
    :return: (trainloader, testloader, classes) where classes is the list of label names
    """
    trainset = Cifar10SearchDataset(root=data_path, train=True, download=True, transform=train_transforms)
    testset = Cifar10SearchDataset(root=data_path, train=False, download=True, transform=test_transforms)

    # Shuffle only the training split; evaluation order stays fixed.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader, trainset.classes
60
+
61
def plot_sample_data(dataloader):
    """Show the first 12 images of one batch in a 3x4 grid, titled by label index."""
    images, labels = next(iter(dataloader))
    plt.figure()
    for idx in range(12):
        plt.subplot(3, 4, idx + 1)
        plt.tight_layout()
        # CHW tensor -> HWC layout for imshow.
        plt.imshow(torch.permute(images[idx], (1, 2, 0)))
        plt.title(labels[idx].item())
        plt.xticks([])
        plt.yticks([])
71
+
72
class trainer:
    """Training/evaluation loop helper that records per-epoch losses and accuracies.

    :param model: Network to train (expected to already be on ``device``)
    :param device: Device used for batches
    :param optimizer: Optimizer updating ``model``'s parameters
    :param scheduler: LR scheduler (stored; not stepped by these loops)
    """

    def __init__(self, model, device, optimizer, scheduler):
        self.model = model
        self.device = device  # FIX: was assigned twice in the original
        self.optimizer = optimizer
        self.scheduler = scheduler

        # Per-epoch history lists of plain Python floats.
        self.train_losses = []
        self.test_losses = []
        self.train_acc = []
        self.test_acc = []

    def getcorrectpredcount(self, prediction, labels):
        """Return how many samples' argmax prediction matches the label."""
        return prediction.argmax(dim=1).eq(labels).sum().item()

    def train(self, train_loader):
        """Run one training epoch; append epoch accuracy/loss and return the histories."""
        self.model.train()
        pbar = tqdm(train_loader)

        train_loss = 0
        correct = 0
        processed = 0
        criterion = nn.CrossEntropyLoss()

        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()

            # Predict
            pred = self.model(data)

            # Calculate loss
            loss = criterion(pred, target)
            train_loss += loss.item()

            # Backpropagation
            loss.backward()
            self.optimizer.step()

            correct += self.getcorrectpredcount(pred, target)
            processed += len(data)

            pbar.set_description(
                desc=f'Train: Loss={loss.item():0.4f} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')

        self.train_acc.append(100 * correct / processed)
        self.train_losses.append(train_loss / len(train_loader))
        return self.train_acc, self.train_losses

    def test(self, test_loader):
        """Evaluate on ``test_loader``; append epoch accuracy/loss and return the histories."""
        self.model.eval()

        test_loss = 0
        correct = 0

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss

                correct += self.getcorrectpredcount(output, target)

        test_loss /= len(test_loader.dataset)
        self.test_acc.append(100. * correct / len(test_loader.dataset))
        self.test_losses.append(test_loss)

        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

        return self.test_acc, self.test_losses

    def visualize_graphs(self):
        """Plot train/test loss and accuracy histories in a 2x2 grid.

        BUGFIX: the histories hold plain Python floats, so the old
        ``t_items.item()`` conversion raised AttributeError, and slicing the
        per-epoch accuracy list with ``[4000:]`` always plotted nothing.
        """
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        axs[0, 0].plot(self.train_losses)
        axs[0, 0].set_title("Training Loss")
        axs[1, 0].plot(self.train_acc)
        axs[1, 0].set_title("Training Accuracy")
        axs[0, 1].plot(self.test_losses)
        axs[0, 1].set_title("Test Loss")
        axs[1, 1].plot(self.test_acc)
        axs[1, 1].set_title("Test Accuracy")

    def evaluate_all_class(self, classes, test_loader):
        """Print per-class accuracy over ``test_loader``."""
        # prepare to count predictions for each class
        correct_pred = {classname: 0 for classname in classes}
        total_pred = {classname: 0 for classname in classes}

        # again no gradients needed
        with torch.no_grad():
            for data in test_loader:
                images, labels = data
                # BUGFIX: move the batch to the configured device — the old
                # code crashed whenever the model was on CUDA.
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                _, predictions = torch.max(outputs, 1)
                # collect the correct predictions for each class
                for label, prediction in zip(labels, predictions):
                    if label == prediction:
                        correct_pred[classes[label]] += 1
                    total_pred[classes[label]] += 1

        # print accuracy for each class
        for classname, correct_count in correct_pred.items():
            accuracy = 100 * float(correct_count) / total_pred[classname]
            print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
181
+
182
+
183
+
184
def evaluate_model(model, loader, device):
    """Plot predictions for 24 randomly chosen samples from ``loader``'s dataset.

    :param model: Trained model used for inference
    :param loader: DataLoader whose underlying dataset is sampled
    :param device: Device the model lives on
    """
    cols, rows = 4, 6
    figure = plt.figure(figsize=(20, 20))
    for index in range(1, cols * rows + 1):
        k = np.random.randint(0, len(loader.dataset))  # random points from test dataset

        img, label = loader.dataset[k]  # separate the image and label
        img = img.unsqueeze(0)  # add a batch dimension
        # FIX: inference only — do not build an autograd graph.
        with torch.no_grad():
            pred = model(img.to(device))  # Prediction

        figure.add_subplot(rows, cols, index)  # making the figure
        # FIX: "Predcited" typo in the displayed title.
        plt.title(f"Predicted label {pred.argmax().item()}\n True Label: {label}")  # title of plot
        plt.axis("off")  # hiding the axis
        plt.imshow(img.squeeze(), cmap="gray")  # showing the plot

    plt.show()
200
+
201
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Useful for tracking how the learning rate changes throughout training.
    Returns None if the optimizer has no param groups.
    (FIX: the original docstring opened with four quotes, leaving a stray
    quotation mark inside the documentation.)
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']
207
+
208
+
209
def lr_calc(model, train_loader, optimizer, criterion):
    """Run an exponential LR range test and plot loss vs. learning rate.

    Restores the model and optimizer to their initial state afterwards.

    :param model: Network whose learning-rate range is probed
    :param train_loader: DataLoader supplying the range-test batches
    :param optimizer: Optimizer whose LR is swept
    :param criterion: Loss function used during the sweep
    """
    finder = LRFinder(model, optimizer, criterion, device="cuda")
    finder.range_test(train_loader, end_lr=10, num_iter=200, step_mode="exp")
    finder.plot()   # inspect the loss-learning-rate curve
    finder.reset()  # restore model and optimizer to their initial state
217
+
218
+
219
def unnormalize(img):
    """Undo CIFAR-10 normalization and return an HWC float32 numpy image.

    :param img: Normalized CHW tensor (3 channels)
    :return: De-normalized image transposed to (H, W, C)
    """
    channel_means = (0.4914, 0.4822, 0.4465)
    channel_stdevs = (0.2470, 0.2435, 0.2616)
    restored = img.numpy().astype(dtype=np.float32)

    # Per-channel in-place de-normalization keeps the array float32.
    for ch, (mean, std) in enumerate(zip(channel_means, channel_stdevs)):
        restored[ch] = restored[ch] * std + mean

    return np.transpose(restored, (1, 2, 0))
228
+
229
+
230
def plot_grad_cam_images(model, test_loader, classes, device):
    """Plot Grad-CAM overlays for the first 10 misclassified test images.

    :param model: Trained model; Grad-CAM hooks into ``model.layer4[-1]``
    :param test_loader: DataLoader for the evaluation split
    :param classes: Class-index -> name mapping used in the titles
    :param device: Device the model lives on
    """
    # set model to evaluation mode
    model.eval()
    # Deepest residual stage — the conventional CAM target for ResNets.
    target_layers = [model.layer4[-1]]

    # Construct the CAM object once, and then re-use it on many images:
    cam = GradCAM(model=model, target_layers=target_layers)

    misclassified_images = []
    actual_labels = []
    actual_targets = []   # raw target tensors, reused as CAM targets below
    predicted_labels = []

    # Collect every misclassified sample in the test set.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, pred = torch.max(output, 1)
            for i in range(len(pred)):
                if pred[i] != target[i]:
                    actual_targets.append(target[i])
                    misclassified_images.append(data[i])
                    actual_labels.append(classes[target[i]])
                    predicted_labels.append(classes[pred[i]])

    # Plot the misclassified images
    # NOTE(review): assumes at least 10 misclassifications exist — fewer
    # would raise IndexError here; confirm with callers.
    fig = plt.figure(figsize=(12, 5))
    for i in range(10):
        sub = fig.add_subplot(2, 5, i+1)
        input_tensor = misclassified_images[i].unsqueeze(dim=0)
        # Explain the *true* class so the heatmap highlights its evidence.
        targets = [ClassifierOutputTarget(actual_targets[i])]
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Overlay the CAM heatmap on the de-normalized image (70% image).
        visualization = show_cam_on_image(unnormalize(misclassified_images[i].cpu()), grayscale_cam, use_rgb=True,image_weight=0.7)

        plt.imshow(visualization)
        sub.set_title("Actual: {}, Pred: {}".format(actual_labels[i], predicted_labels[i]), color='red')
        plt.tight_layout()
    plt.show()
270
+
271
def plot_misclassified_images(model, test_loader, classes, device):
    """Plot up to 10 misclassified test images with actual/predicted labels.

    :param model: Trained model to evaluate
    :param test_loader: DataLoader for the evaluation split
    :param classes: Class-index -> name mapping used in the titles
    :param device: Device the model lives on
    """
    # set model to evaluation mode
    model.eval()

    misclassified_images = []
    actual_labels = []
    predicted_labels = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, pred = torch.max(output, 1)
            for i in range(len(pred)):
                if pred[i] != target[i]:
                    misclassified_images.append(data[i])
                    actual_labels.append(classes[target[i]])
                    predicted_labels.append(classes[pred[i]])
            # FIX: only 10 images are plotted, so stop scanning (and keeping
            # on-device tensors for) the rest of the test set.
            if len(misclassified_images) >= 10:
                break

    # Plot the misclassified images
    # FIX: guard against fewer than 10 misclassifications (old code raised
    # IndexError in that case).
    fig = plt.figure(figsize=(12, 5))
    for i in range(min(10, len(misclassified_images))):
        sub = fig.add_subplot(2, 5, i+1)
        npimg = unnormalize(misclassified_images[i].cpu())
        plt.imshow(npimg, cmap='gray', interpolation='none')
        sub.set_title("Actual: {}, Pred: {}".format(actual_labels[i], predicted_labels[i]),color='red')
        plt.tight_layout()
    plt.show()
visualize.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Function used for visualization of data and results
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jun 21, 2023
6
+ """
7
+ # Standard Library Imports
8
+ import math
9
+ from dataclasses import dataclass
10
+ from typing import NoReturn
11
+
12
+ # Third-Party Imports
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import pandas as pd
16
+ import seaborn as sn
17
+ import torch
18
+ import torch.nn as nn
19
+ from torchvision import transforms
20
+ from sklearn.metrics import confusion_matrix
21
+
22
+
23
+ # ---------------------------- DATA SAMPLES ----------------------------
24
def display_mnist_data_samples(dataset: 'DataLoader object', number_of_samples: int) -> NoReturn:
    """
    Function to display samples for dataloader
    :param dataset: Train or Test dataset transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    """
    # Collect exactly `number_of_samples` items from the dataset
    # (FIX: the old `not count <= n` break gathered one extra sample).
    batch_data = []
    batch_label = []
    for count, item in enumerate(dataset):
        if count >= number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])

    # Plot the samples from the batch
    fig = plt.figure()
    x_count = 5
    # FIX: use ceil so counts that are not a multiple of 5 still allocate
    # enough subplot rows (floor made plt.subplot raise for e.g. 12 samples).
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)

    # Plot the samples from the batch
    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        plt.tight_layout()
        plt.imshow(batch_data[i].squeeze(), cmap='gray')
        plt.title(batch_label[i])
        plt.xticks([])
        plt.yticks([])
52
+
53
+
54
def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
    """
    Function to display samples for data_set
    :param data_set: Train or Test data_set transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    :param classes: Name of classes to be displayed
    """
    # Collect exactly `number_of_samples` items from the data_set
    # (FIX: the old `not count <= n` break gathered one extra sample).
    batch_data = []
    batch_label = []
    for count, item in enumerate(data_set):
        if count >= number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])
    batch_data = torch.stack(batch_data, dim=0).numpy()

    # Plot the samples from the batch
    fig = plt.figure()
    x_count = 5
    # FIX: ceil instead of floor so non-multiples of 5 still get enough
    # subplot rows (floor made plt.subplot raise for e.g. 12 samples).
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)

    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        plt.tight_layout()
        plt.imshow(np.transpose(batch_data[i].squeeze(), (1, 2, 0)))
        plt.title(classes[batch_label[i]])
        plt.xticks([])
        plt.yticks([])
83
+
84
+
85
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
86
def display_cifar_misclassified_data(data: list,
                                     classes: list[str],
                                     inv_normalize: transforms.Normalize,
                                     number_of_samples: int = 10):
    """
    Function to plot images with labels
    :param data: List[Tuple(image, label, prediction)]  (code reads index 2 for the prediction)
    :param classes: Name of classes in the dataset
    :param inv_normalize: Inverse-normalization transform used for display
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(10, 10))

    x_count = 5
    # FIX: ceil instead of floor so non-multiples of 5 still get enough
    # subplot rows (floor made plt.subplot raise for e.g. 12 samples).
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)

    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        img = data[i][0].squeeze().to('cpu')
        img = inv_normalize(img)
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])
110
+
111
+
112
def display_mnist_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Function to plot images with labels
    :param data: List[Tuple(image, label, prediction)]  (code reads index 2 for the prediction)
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(8, 5))

    x_count = 5
    # FIX: ceil instead of floor so non-multiples of 5 still get enough
    # subplot rows (floor made plt.subplot raise for e.g. 12 samples).
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)

    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        img = data[i][0].squeeze(0).to('cpu')
        plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
        plt.title(r"Correct: " + str(data[i][1].item()) + '\n' + 'Output: ' + str(data[i][2].item()))
        plt.xticks([])
        plt.yticks([])
131
+
132
+
133
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
134
def visualize_cifar_augmentation(data_set, data_transforms):
    """
    Function to visualize the augmented data
    :param data_set: Dataset without transformations
    :param data_transforms: Dictionary of transforms (name -> albumentations pipeline)
    """
    # One fixed sample (index 6) is shown under each augmentation for comparison.
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (key, trans) in enumerate(data_transforms.items()):
        # NOTE(review): this breaks *before* applying the final dict entry,
        # so the last transform is never displayed — presumably a tensor
        # conversion that cannot be imshow'n; confirm with callers.
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        # albumentations call style: keyword image, result dict keyed "image".
        augmented = trans(image=sample)['image']
        plt.imshow(augmented)
        plt.title(key)
        plt.xticks([])
        plt.yticks([])
153
+
154
+
155
def visualize_mnist_augmentation(data_set, data_transforms):
    """
    Function to visualize the augmented data
    :param data_set: Dataset to visualize the augmentations
    :param data_transforms: Dictionary of transforms (name -> torchvision-style transform)
    """
    # One fixed sample (index 6) is shown under each augmentation for comparison.
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (key, trans) in enumerate(data_transforms.items()):
        # NOTE(review): this breaks *before* applying the final dict entry,
        # so the last transform is never displayed — presumably intentional;
        # confirm with callers.
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        # torchvision call style: positional image, tensor result.
        img = trans(sample).to('cpu')
        plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
        plt.title(key)
        plt.xticks([])
        plt.yticks([])
174
+
175
+
176
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
177
def display_loss_and_accuracies(train_losses: list,
                                train_acc: list,
                                test_losses: list,
                                test_acc: list,
                                plot_size: tuple = (10, 10)) -> NoReturn:
    """
    Display training and test information (losses and accuracies) in a 2x2 grid.
    :param train_losses: List containing training loss of each epoch
    :param train_acc: List containing training accuracy of each epoch
    :param test_losses: List containing test loss of each epoch
    :param test_acc: List containing test accuracy of each epoch
    :param plot_size: Size of the plot
    """
    # Create a 2x2 grid of axes of the requested size.
    fig, axs = plt.subplots(2, 2, figsize=plot_size)

    # (series, grid position, panel title) for each of the four panels:
    # losses on the top row, accuracies on the bottom row.
    panels = [
        (train_losses, (0, 0), "Training Loss"),
        (train_acc, (1, 0), "Training Accuracy"),
        (test_losses, (0, 1), "Test Loss"),
        (test_acc, (1, 1), "Test Accuracy"),
    ]
    for series, (row, col), title in panels:
        axs[row, col].plot(series)
        axs[row, col].set_title(title)
204
+
205
+
206
+ # ---------------------------- Feature Maps and Kernels ----------------------------
207
+
208
@dataclass
class ConvLayerInfo:
    """
    Data Class to store Conv layer's information
    """
    layer_number: int                         # 1-based position among discovered conv layers
    weights: torch.nn.parameter.Parameter     # the layer's weight tensor
    layer_info: torch.nn.modules.conv.Conv2d  # the conv module itself


class FeatureMapVisualizer:
    """
    Class to visualize Feature Map of the Layers
    """

    def __init__(self, model):
        """
        Constructor
        :param model: Model Architecture
        """
        self.conv_layers = []
        self.outputs = []
        self.layerwise_kernels = None

        # Dissect the model: collect every Conv2d found inside a top-level
        # nn.Sequential child.
        # NOTE(review): Conv2d modules that are *direct* children of the
        # model (not wrapped in a Sequential) are not discovered — confirm
        # this is intended for the models this is used with.
        counter = 0
        for children in model.children():
            if type(children) == nn.Sequential:
                for child in children:
                    if type(child) == nn.Conv2d:
                        counter += 1
                        self.conv_layers.append(ConvLayerInfo(layer_number=counter,
                                                              weights=child.weight,
                                                              layer_info=child))

    def get_model_weights(self):
        """
        Method to get the model weights
        """
        return [layer.weights for layer in self.conv_layers]

    def get_conv_layers(self):
        """
        Get the convolution layers
        """
        return [layer.layer_info for layer in self.conv_layers]

    def get_total_conv_layers(self) -> int:
        """
        Get total number of convolution layers
        """
        return len(self.get_conv_layers())

    def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
        """
        Get feature maps from all the kernels of all the layers
        :param image: Image to be passed to the network (C, H, W)
        """
        image = image.unsqueeze(0)  # add batch dimension
        image = image.to('cpu')

        outputs = {}

        # Chain the image through the conv layers only — activations/pooling
        # from the original model are intentionally skipped.
        for layer in self.get_conv_layers():
            image = layer(image)
            outputs[str(layer)] = image
        self.outputs = outputs
        return outputs

    def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int) -> None:
        """
        Function to visualize feature map of kernel number from each layer
        :param image: Image passed to the network
        :param kernel_number: Number of kernel in each layer (should be <= the minimum channel count in the network)
        """
        # List to store processed feature maps
        processed = []

        # Get feature maps from all kernels of all the conv layers
        outputs = self.feature_maps_of_all_kernels(image)

        # Extract the n_th kernel's output from each layer and convert it to grayscale
        for feature_map in outputs.values():
            try:
                feature_map = feature_map[0][kernel_number]
            except IndexError:
                print("Filter number should be less than the minimum number of channels in a network")
                break
            # BUGFIX: this block lived in a `finally`, so on IndexError it
            # still ran against the *unsliced* feature map and appended
            # garbage before breaking.
            gray_scale = feature_map / feature_map.shape[0]
            processed.append(gray_scale.data.numpy())

        # Plot the Feature maps with layer and kernel number
        x_range = len(outputs) // 5 + 4
        fig = plt.figure(figsize=(10, 10))
        for i in range(len(processed)):
            a = fig.add_subplot(x_range, 5, i + 1)
            imgplot = plt.imshow(processed[i])
            a.axis("off")
            title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
            a.set_title(title, fontsize=10)

    def get_max_kernel_number(self):
        """
        Function to get maximum number of kernels in the network (for a layer)
        """
        channels = [layer.out_channels for layer in self.get_conv_layers()]
        self.layerwise_kernels = channels
        return max(channels)

    def visualize_kernels_from_layer(self, layer_number: int):
        """
        Visualize Kernels from a layer
        :param layer_number: 1-based number of the layer whose kernels are shown
        """
        # Get the kernels number for each layer
        self.get_max_kernel_number()

        # Zero Indexing
        layer_number = layer_number - 1
        _kernels = self.layerwise_kernels[layer_number]

        grid = math.ceil(math.sqrt(_kernels))

        plt.figure(figsize=(5, 4))
        model_weights = self.get_model_weights()
        _layer_weights = model_weights[layer_number].cpu()
        for i, kernel in enumerate(_layer_weights):
            plt.subplot(grid, grid, i + 1)
            # Only the first input channel of each kernel is rendered.
            plt.imshow(kernel[0, :, :].detach(), cmap='gray')
            plt.axis('off')
        plt.show()
347
+
348
+
349
+ # ---------------------------- Confusion Matrix ----------------------------
350
def visualize_confusion_matrix(classes: list[str], device: str, model: 'DL Model',
                               test_loader: torch.utils.data.DataLoader):
    """
    Function to generate and visualize confusion matrix
    :param classes: List of class names
    :param device: cuda/cpu
    :param model: Model Architecture
    :param test_loader: DataLoader for test set
    """
    nb_classes = len(classes)
    # BUGFIX: the caller's `device` argument was overwritten with 'cuda',
    # breaking CPU-only runs; respect the parameter instead.
    cm = torch.zeros(nb_classes, nb_classes)

    model = model.to(device)  # move once, not per batch
    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)
            preds = preds.argmax(dim=1)

            # Accumulate counts: row = true class, column = predicted class.
            for t, p in zip(labels.view(-1), preds.view(-1)):
                cm[t, p] = cm[t, p] + 1

    # Build confusion matrix
    # BUGFIX: the heatmap was previously built from sklearn's
    # confusion_matrix over only the *last* batch; use the matrix
    # accumulated over all batches instead.
    cf_matrix = cm.numpy()
    # Row-normalize so each cell shows the fraction of that true class.
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=[i for i in classes],
                         columns=[i for i in classes])
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)