Shivdutta committed on
Commit
4e9cc3f
·
verified ·
1 Parent(s): 8d6c0cb

Upload 3 files

Browse files
Files changed (3) hide show
  1. datasets.py +35 -0
  2. resnetS11.py +397 -0
  3. utils.py +568 -0
datasets.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ from torchvision import datasets, transforms
3
+
4
+
5
class TransformedDataset(datasets.CIFAR10):
    """
    CIFAR10 dataset variant that invokes its transform albumentations-style.

    Unlike the stock torchvision dataset (which calls ``transform(img)``), this
    class calls ``transform(image=...)`` and unpacks the ``"image"`` key, so an
    ``albumentations.Compose`` can be passed directly.

    Args:
        root (str, optional): Root directory where the dataset is stored. Default is "./data".
        train (bool, optional): Specifies if the dataset is for training or testing. Default is True.
        download (bool, optional): If True, downloads the dataset from the internet and places it in the root directory. Default is True.
        transform (list, optional): Transformations to apply to the images. Default is None.
    """

    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Retrieves the item at the specified index.

        Args:
            index (int): Index of the item to retrieve.

        Returns:
            Tuple: A tuple containing the (possibly transformed) image and its label.
        """
        image = self.data[index]
        label = self.targets[index]

        if self.transform:
            # Albumentations-style call: keyword argument in, dict out.
            image = self.transform(image=image)["image"]
        return image, label
resnetS11.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import albumentations as A
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.optim as optim
11
+ from torch.utils.data import DataLoader, random_split
12
+
13
+ from torchvision import transforms
14
+ from torchvision.datasets import CIFAR10
15
+
16
+ from pytorch_lightning import LightningModule, Trainer
17
+ from torchmetrics import Accuracy
18
+
19
+ from datasets import TransformedDataset
20
+ from utils import get_cifar_statistics
21
+ from utils import visualize_cifar_augmentation, display_cifar_data_samples
22
+
23
+
24
class BasicBlock(LightningModule):
    """
    Standard two-convolution residual block (ResNet-18/34 style).

    Applies conv-bn-relu, conv-bn, then adds a shortcut branch and applies a
    final ReLU. The shortcut is a 1x1 projection whenever the spatial or
    channel dimensions change, otherwise the identity.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Projection shortcut only when the residual cannot be added directly.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the block: two conv-bn stages plus the shortcut branch."""
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
47
+
48
class LITResNet(LightningModule):
    """
    ResNet-18 style image classifier for CIFAR10 packaged as a LightningModule.

    Besides the network itself, this module owns the CIFAR10 data pipeline
    (download, train/val/test split, dataloaders) and helpers for statistics,
    sample/augmentation visualization and misclassification inspection.
    """

    def __init__(self, class_names, data_dir='/data/'):
        """
        Constructor
        :param class_names: Ordered list of class names (index -> label)
        :param data_dir: Root directory where the CIFAR10 data lives
        """
        # Initialize the Module class
        super(LITResNet, self).__init__()

        # Initialize variables
        self.classes = class_names
        self.data_dir = data_dir
        self.num_classes = 10
        self._learning_rate = 0.03
        # Inverse of Normalize(mean=0.50, std=0.23): undoes normalization for plotting
        self.inv_normalize = transforms.Normalize(
            mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
            std=[1 / 0.23, 1 / 0.23, 1 / 0.23]
        )
        self.batch_size = 512
        self.epochs = 24
        self.accuracy = Accuracy(task='multiclass',
                                 num_classes=10)
        self.train_transforms = transforms.Compose([transforms.ToTensor()])
        self.test_transforms = transforms.Compose([transforms.ToTensor()])
        self.stats_train = None
        self.stats_test = None
        self.cifar10_train = None
        self.cifar10_test = None
        self.cifar10_val = None
        self.misclassified_data = None

        # NOTE(review): the placeholders below are never assigned by
        # model_layers(); kept so any existing external references don't break.
        self.prep_layer = None
        self.custom_block1 = None
        self.custom_block2 = None
        self.custom_block3 = None
        self.resnet_block1 = None
        self.resnet_block3 = None
        self.pool4 = None
        self.fc = None
        self.dropout_value = None
        self.model_layers(BasicBlock, [2, 2, 2, 2])

    # ##################################################################################################
    # ################################ Model Architecture Related Hooks ################################
    # ##################################################################################################
    def model_layers(self, block, num_blocks, num_classes=10):
        """
        Method to initialize layers for the model
        :param block: Residual block class (e.g. BasicBlock)
        :param num_blocks: Blocks per stage, e.g. [2, 2, 2, 2] for ResNet-18
        :param num_classes: Number of output classes
        """
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """
        Build one ResNet stage: the first block may downsample (given stride),
        the remaining blocks keep stride 1.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass for model training
        :param x: Input batch of images
        :return: Raw class logits (no softmax applied)
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    # ##################################################################################################
    # ############################## Training Configuration Related Hooks ##############################
    # ##################################################################################################

    def configure_optimizers(self):
        """
        Method to configure the optimizer and learning rate scheduler
        """
        weight_decay = 1e-4
        # BUG FIX: the optimizer previously used a hard-coded lr=0.03, so the
        # learning_rate setter only affected the scheduler's max_lr. Use the
        # attribute for both so they stay consistent.
        optimizer = optim.Adam(self.parameters(), lr=self._learning_rate, weight_decay=weight_decay)

        # One-cycle policy, stepped once per batch (interval='step')
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                        max_lr=self._learning_rate,
                                                        steps_per_epoch=len(self.train_dataloader()),
                                                        epochs=self.epochs,
                                                        pct_start=5 / self.epochs,
                                                        div_factor=100,
                                                        three_phase=False,
                                                        final_div_factor=100,
                                                        anneal_strategy="linear"
                                                        )
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    @property
    def learning_rate(self) -> float:
        """
        Method to get the learning rate value
        """
        return self._learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float):
        """
        Method to set the learning rate value
        :param value: Updated value of learning rate
        """
        self._learning_rate = value

    def set_training_confi(self, *, epochs, batch_size):
        """
        Method to set parameters required for model training
        (NOTE(review): method name keeps its original spelling for backward compatibility)
        :param epochs: Number of epochs for which model is to be trained
        :param batch_size: Batch Size
        """
        self.epochs = epochs
        self.batch_size = batch_size

    # #################################################################################################
    # ################################## Training Loop Related Hooks ##################################
    # #################################################################################################
    def training_step(self, train_batch, batch_index):
        """
        Method called on training dataset to train the model
        :param train_batch: Batch containing images and labels
        :param batch_index: Index of the batch
        """
        x, y = train_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Method called on validation dataset to check if the model is learning
        :param batch: Batch containing images and labels
        :param batch_idx: Index of the batch
        """
        x, y = batch
        logits = self.forward(x)
        # BUG FIX: forward() returns raw logits, but nll_loss expects
        # log-probabilities, so the reported validation loss was wrong.
        # cross_entropy (= log_softmax + nll_loss) matches the training loss.
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """
        Method called on test dataset to check model performance on unseen data
        :param batch: Batch containing images and labels
        :param batch_idx: Index of the batch
        """
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    # ##############################################################################################
    # ##################################### Data Related Hooks #####################################
    # ##############################################################################################

    def set_transforms(self, train_set_transforms: dict, test_set_transforms: dict):
        """
        Method to set the transformations to be done on training and test datasets
        :param train_set_transforms: Dictionary of transformations for training dataset
        :param test_set_transforms: Dictionary of transformations for test dataset
        """
        # list(...) because A.Compose expects a sequence, not a dict view
        self.train_transforms = A.Compose(list(train_set_transforms.values()))
        self.test_transforms = A.Compose(list(test_set_transforms.values()))

    def prepare_data(self):
        """
        Method to download the dataset
        """
        self.stats_train = CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
        self.stats_test = CIFAR10('./data', train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage=None):
        """
        Method to split the dataset into train, test and val
        """
        # Only if the dataset is not already split, perform the split operation.
        # (`is None` rather than truthiness: an already-loaded Dataset must not
        # be re-created.)
        if self.cifar10_train is None and self.cifar10_test is None and self.cifar10_val is None:

            # Assign train/val datasets for use in dataloaders
            if stage == "fit" or stage is None:
                cifar10_full = TransformedDataset(self.data_dir, train=True, download=True, transform=self.train_transforms)
                self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45_000, 5_000])

            # Assign test dataset for use in dataloader(s)
            if stage == "test" or stage is None:
                self.cifar10_test = TransformedDataset(self.data_dir, train=False, download=True,
                                                       transform=self.test_transforms)

    def train_dataloader(self):
        """
        Method to return the DataLoader for Training set
        """
        return DataLoader(self.cifar10_train, batch_size=self.batch_size, num_workers=os.cpu_count())

    def val_dataloader(self):
        """
        Method to return the DataLoader for the Validation set
        """
        return DataLoader(self.cifar10_val, batch_size=self.batch_size, num_workers=os.cpu_count())

    def test_dataloader(self):
        """
        Method to return the DataLoader for the Test set
        """
        return DataLoader(self.cifar10_test, batch_size=self.batch_size, num_workers=os.cpu_count())

    def get_statistics(self, data_set_type="Train"):
        """
        Method to get the statistics for CIFAR10 dataset
        :param data_set_type: "Train" for training statistics, anything else for test
        """
        # Execute self.prepare_data() only if not done earlier
        if self.stats_train is None or self.stats_test is None:
            self.prepare_data()

        # Print stats for selected dataset
        if data_set_type == "Train":
            get_cifar_statistics(self.stats_train)
        else:
            get_cifar_statistics(self.stats_test, data_set_type="Test")

    def display_data_samples(self, dataset="train", num_of_images=20):
        """
        Method to display data samples
        :param dataset: "train" to sample from the training set, anything else for test
        :param num_of_images: Number of images to display
        """
        # BUG FIX: the original wrapped `assert self.stats_train` in
        # `except AttributeError`, but a failing assert raises AssertionError,
        # so prepare_data() was never called and the method crashed when the
        # stats datasets had not been downloaded yet.
        # (This method was also defined twice in the class; the duplicate was removed.)
        if self.stats_train is None:
            self.prepare_data()

        if dataset == "train":
            display_cifar_data_samples(self.stats_train, num_of_images, self.classes)
        else:
            display_cifar_data_samples(self.stats_test, num_of_images, self.classes)

    @staticmethod
    def visualize_augmentation(aug_set_transforms: dict):
        """
        Method to visualize augmentations
        :param aug_set_transforms: Dictionary of transformations to be visualized
        """
        aug_train = TransformedDataset('./data', train=True, download=True)
        visualize_cifar_augmentation(aug_train, aug_set_transforms)

    # #############################################################################################
    # ############################## Misclassified Data Related Hooks ##############################
    # #############################################################################################

    def get_misclassified_data(self):
        """
        Function to run the model on the test set and return misclassified images
        :return: List of (image, label, pred) tuples; results are cached
        """
        if self.misclassified_data:
            return self.misclassified_data

        self.misclassified_data = []
        self.prepare_data()
        self.setup()

        test_loader = self.test_dataloader()

        # Switch off dropout/batch-norm updates for deterministic inference,
        # consistent with utils.get_misclassified_data
        self.eval()

        with torch.no_grad():
            # Extract images, labels in a batch
            for data, target in test_loader:

                # Migrate the data to the device
                data, target = data.to(self.device), target.to(self.device)

                # Extract single image, label from the batch
                for image, label in zip(data, target):

                    # Add batch dimension to the image
                    image = image.unsqueeze(0)

                    # Get the model prediction on the image
                    output = self.forward(image)

                    # Convert the output from one-hot encoding to a value
                    pred = output.argmax(dim=1, keepdim=True)

                    # If prediction is incorrect, append the data
                    if pred != label:
                        self.misclassified_data.append((image, label, pred))
        return self.misclassified_data

    def display_cifar_misclassified_data(self, number_of_samples: int = 10):
        """
        Function to plot misclassified images with correct/predicted labels
        :param number_of_samples: Number of images to print
        """
        if not self.misclassified_data:
            self.misclassified_data = self.get_misclassified_data()

        fig = plt.figure(figsize=(10, 10))

        # Fixed 5-column grid; add rows as needed
        x_count = 5
        y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

        for i in range(number_of_samples):
            plt.subplot(y_count, x_count, i + 1)
            img = self.misclassified_data[i][0].squeeze().to('cpu')
            img = self.inv_normalize(img)
            plt.imshow(np.transpose(img, (1, 2, 0)))
            plt.title(
                r"Correct: " + self.classes[self.misclassified_data[i][1].item()] + '\n' + 'Output: ' + self.classes[
                    self.misclassified_data[i][2].item()])
            plt.xticks([])
            plt.yticks([])
utils.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import NoReturn
3
+
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ from torchsummary import summary
8
+ from torchvision import transforms
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.image import show_cam_on_image
11
+
12
+ from dataclasses import dataclass
13
+ from typing import NoReturn
14
+ import pandas as pd
15
+ import seaborn as sn
16
+ import torch
17
+ import torch.nn as nn
18
+ from torchvision import transforms
19
+ from sklearn.metrics import confusion_matrix
20
+
21
+
22
+ # ---------------------------- DATA SAMPLES ----------------------------
23
def display_mnist_data_samples(dataset: 'DataLoader object', number_of_samples: int) -> NoReturn:
    """
    Function to display samples from a dataset
    :param dataset: Train or Test dataset transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    """
    # Collect the first samples from the dataset
    batch_data, batch_label = [], []
    for count, item in enumerate(dataset):
        if count > number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])

    # Fixed 5-column grid; grow rows when more than 5 samples are requested
    fig = plt.figure()
    x_count = 5
    y_count = math.floor(number_of_samples / x_count) if number_of_samples > 5 else 1

    # Plot the collected samples with their labels
    for idx in range(number_of_samples):
        plt.subplot(y_count, x_count, idx + 1)
        plt.tight_layout()
        plt.imshow(batch_data[idx].squeeze(), cmap='gray')
        plt.title(batch_label[idx])
        plt.xticks([])
        plt.yticks([])
51
+
52
+
53
def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
    """
    Function to display samples from a data_set
    :param data_set: Train or Test data_set transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    :param classes: Name of classes to be displayed
    """
    # Collect the first samples from the data_set
    batch_data, batch_label = [], []
    for count, item in enumerate(data_set):
        if count > number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])
    # Stack into a single (N, C, H, W) array for plotting
    batch_data = torch.stack(batch_data, dim=0).numpy()

    # Fixed 5-column grid; grow rows when more than 5 samples are requested
    fig = plt.figure()
    x_count = 5
    y_count = math.floor(number_of_samples / x_count) if number_of_samples > 5 else 1

    for idx in range(number_of_samples):
        plt.subplot(y_count, x_count, idx + 1)
        plt.tight_layout()
        # CHW -> HWC for imshow
        plt.imshow(np.transpose(batch_data[idx].squeeze(), (1, 2, 0)))
        plt.title(classes[batch_label[idx]])
        plt.xticks([])
        plt.yticks([])
82
+
83
+
84
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
85
def display_cifar_misclassified_data(data: list,
                                     classes: list[str],
                                     inv_normalize: transforms.Normalize,
                                     number_of_samples: int = 10):
    """
    Function to plot images with labels
    :param data: List[Tuple(image, label, pred)]
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform undoing the dataset normalization
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(10, 10))

    # Fixed 5-column grid; grow rows when more than 5 samples are requested
    x_count = 5
    y_count = math.floor(number_of_samples / x_count) if number_of_samples > 5 else 1

    for idx in range(number_of_samples):
        plt.subplot(y_count, x_count, idx + 1)
        # Undo normalization and move to CPU so matplotlib can render it
        img = inv_normalize(data[idx][0].squeeze().to('cpu'))
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.title(r"Correct: " + classes[data[idx][1].item()] + '\n' + 'Output: ' + classes[data[idx][2].item()])
        plt.xticks([])
        plt.yticks([])
109
+
110
+
111
def display_mnist_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Function to plot images with labels
    :param data: List[Tuple(image, label, pred)]
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(8, 5))

    # Fixed 5-column grid; grow rows when more than 5 samples are requested
    x_count = 5
    y_count = math.floor(number_of_samples / x_count) if number_of_samples > 5 else 1

    for idx in range(number_of_samples):
        plt.subplot(y_count, x_count, idx + 1)
        # Drop the batch dimension and move to CPU for plotting
        img = data[idx][0].squeeze(0).to('cpu')
        plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
        plt.title(r"Correct: " + str(data[idx][1].item()) + '\n' + 'Output: ' + str(data[idx][2].item()))
        plt.xticks([])
        plt.yticks([])
130
+
131
+
132
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
133
def visualize_cifar_augmentation(data_set, data_transforms):
    """
    Function to visualize the augmented data
    :param data_set: Dataset without transformations
    :param data_transforms: Dictionary of transforms (name -> albumentations transform)
    """
    # A fixed sample is used so the augmentations are directly comparable
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (key, trans) in enumerate(data_transforms.items()):
        # NOTE(review): the final dict entry is deliberately never plotted
        # (presumably a ToTensor/Normalize step) — confirm with callers.
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        # Albumentations-style call: keyword argument in, dict out
        augmented = trans(image=sample)['image']
        plt.imshow(augmented)
        plt.title(key)
        plt.xticks([])
        plt.yticks([])
152
+
153
+
154
def visualize_mnist_augmentation(data_set, data_transforms):
    """
    Function to visualize the augmented data
    :param data_set: Dataset to visualize the augmentations
    :param data_transforms: Dictionary of transforms (name -> torchvision transform)
    """
    # A fixed sample is used so the augmentations are directly comparable
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (key, trans) in enumerate(data_transforms.items()):
        # NOTE(review): the final dict entry is deliberately never plotted
        # (presumably a ToTensor/Normalize step) — confirm with callers.
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        img = trans(sample).to('cpu')
        plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
        plt.title(key)
        plt.xticks([])
        plt.yticks([])
173
+
174
+
175
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
176
def display_loss_and_accuracies(train_losses: list,
                                train_acc: list,
                                test_losses: list,
                                test_acc: list,
                                plot_size: tuple = (10, 10)) -> NoReturn:
    """
    Function to display training and test information (losses and accuracies)
    :param train_losses: List containing training loss of each epoch
    :param train_acc: List containing training accuracy of each epoch
    :param test_losses: List containing test loss of each epoch
    :param test_acc: List containing test accuracy of each epoch
    :param plot_size: Size of the plot
    """
    # 2x2 grid: columns = train/test, rows = loss/accuracy
    fig, axs = plt.subplots(2, 2, figsize=plot_size)

    panels = [
        (axs[0, 0], train_losses, "Training Loss"),
        (axs[1, 0], train_acc, "Training Accuracy"),
        (axs[0, 1], test_losses, "Test Loss"),
        (axs[1, 1], test_acc, "Test Accuracy"),
    ]
    for axis, series, title in panels:
        axis.plot(series)
        axis.set_title(title)
203
+
204
+
205
+ # ---------------------------- Feature Maps and Kernels ----------------------------
206
+
207
@dataclass
class ConvLayerInfo:
    """
    Data Class to store Conv layer's information
    """
    # 1-based position of this conv layer among all convs collected from the model
    layer_number: int
    # Reference to the layer's weight tensor (shared with the module, not copied)
    weights: torch.nn.parameter.Parameter
    # The Conv2d module itself
    layer_info: torch.nn.modules.conv.Conv2d
215
+
216
+
217
class FeatureMapVisualizer:
    """
    Class to visualize Feature Maps and kernels of a model's Conv2d layers.
    """

    def __init__(self, model):
        """
        Constructor
        :param model: Model Architecture
        """
        self.conv_layers = []
        self.outputs = []
        self.layerwise_kernels = None

        # Dissect the model: collect every Conv2d found one level inside an
        # nn.Sequential child (deeper nesting is intentionally not traversed).
        counter = 0
        for children in model.children():
            # isinstance instead of `type(...) ==`: idiomatic and covers subclasses
            if isinstance(children, nn.Sequential):
                for child in children:
                    if isinstance(child, nn.Conv2d):
                        counter += 1
                        self.conv_layers.append(ConvLayerInfo(layer_number=counter,
                                                              weights=child.weight,
                                                              layer_info=child)
                                                )

    def get_model_weights(self):
        """
        Method to get the weight tensors of all collected conv layers
        """
        return [layer.weights for layer in self.conv_layers]

    def get_conv_layers(self):
        """
        Get the collected Conv2d modules
        """
        return [layer.layer_info for layer in self.conv_layers]

    def get_total_conv_layers(self) -> int:
        """
        Get total number of collected convolution layers
        """
        return len(self.get_conv_layers())

    def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
        """
        Get feature maps from all the kernels of all the layers by chaining the
        conv layers over the input image.
        :param image: Image to be passed to the network
        NOTE(review): everything is run on CPU — assumes the model weights are
        on CPU too; confirm with callers.
        """
        image = image.unsqueeze(0)  # add batch dimension
        image = image.to('cpu')

        outputs = {}
        for layer in self.get_conv_layers():
            image = layer(image)
            outputs[str(layer)] = image
        self.outputs = outputs
        return outputs

    def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int) -> None:
        """
        Function to visualize the feature map of one kernel number from each layer
        :param image: Image passed to the network
        :param kernel_number: Kernel index (must be < the smallest channel count in the network)
        """
        # List to store processed feature maps
        processed = []

        # Get feature maps from all kernels of all the conv layers
        outputs = self.feature_maps_of_all_kernels(image)

        # Extract the n_th kernel's output from each layer and convert it to grayscale
        for feature_map in outputs.values():
            try:
                feature_map = feature_map[0][kernel_number]
            except IndexError:
                print("Filter number should be less than the minimum number of channels in a network")
                break
            else:
                # BUG FIX: this block previously ran in `finally`, so after an
                # IndexError the *un-indexed* feature map was still processed
                # and appended before the loop broke. `else` runs only on success.
                gray_scale = feature_map / feature_map.shape[0]
                processed.append(gray_scale.data.numpy())

        # Plot the feature maps with layer and kernel number
        x_range = len(outputs) // 5 + 4
        fig = plt.figure(figsize=(10, 10))
        for i in range(len(processed)):
            a = fig.add_subplot(x_range, 5, i + 1)
            plt.imshow(processed[i])
            a.axis("off")
            title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
            a.set_title(title, fontsize=10)
        return fig

    def get_max_kernel_number(self):
        """
        Function to get the maximum number of kernels in the network (for a layer);
        also caches the per-layer kernel counts in self.layerwise_kernels.
        """
        layers = self.get_conv_layers()
        channels = [layer.out_channels for layer in layers]
        self.layerwise_kernels = channels
        return max(channels)

    def visualize_kernels_from_layer(self, layer_number: int):
        """
        Visualize kernels from a layer
        :param layer_number: 1-based number of the layer whose kernels are visualized
        """
        # Get the kernels number for each layer
        self.get_max_kernel_number()

        # Zero Indexing
        layer_number = layer_number - 1
        _kernels = self.layerwise_kernels[layer_number]

        # Near-square grid large enough for all kernels
        grid = math.ceil(math.sqrt(_kernels))

        fig = plt.figure(figsize=(5, 4))
        model_weights = self.get_model_weights()
        _layer_weights = model_weights[layer_number].cpu()
        for i, filter in enumerate(_layer_weights):
            plt.subplot(grid, grid, i + 1)
            # Only the first input channel of each filter is shown
            plt.imshow(filter[0, :, :].detach(), cmap='gray')
            plt.axis('off')
        return fig
348
+
349
+
350
+ # ---------------------------- Confusion Matrix ----------------------------
351
def visualize_confusion_matrix(classes: list[str], device: str, model: 'DL Model',
                               test_loader: torch.utils.data.DataLoader):
    """
    Function to generate and visualize a confusion matrix over the full test set
    :param classes: List of class names
    :param device: cuda/cpu
    :param model: Model Architecture
    :param test_loader: DataLoader for test set
    """
    # BUG FIX: the device parameter was previously overridden with a hard-coded
    # 'cuda', crashing on CPU-only machines. Honor the argument instead.
    model = model.to(device)  # hoisted out of the batch loop
    model.eval()

    # BUG FIX: the confusion matrix was previously built from only the LAST
    # batch's labels/preds (the accumulated tensor was discarded). Collect
    # predictions from every batch instead.
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs).argmax(dim=1)

            all_labels.append(labels.to('cpu'))
            all_preds.append(preds.to('cpu'))

    all_labels = torch.cat(all_labels)
    all_preds = torch.cat(all_preds)

    # Build confusion matrix, row-normalized so each row sums to 1
    cf_matrix = confusion_matrix(all_labels, all_preds)
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=[i for i in classes],
                         columns=[i for i in classes])
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)
386
+
387
def get_summary(model: 'object of model architecture', input_size: tuple) -> NoReturn:
    """
    Function to get the summary of the model architecture
    :param model: Object of model architecture class
    :param input_size: Input data shape (Channels, Height, Width)
    """
    # Prefer GPU when available, otherwise fall back to CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = model.to(device)
    summary(network, input_size=input_size)
397
+
398
+
399
def get_misclassified_data(model, device, test_loader):
    """
    Function to run the model on test set and return misclassified images
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set
    :return: List of (image, true_label, predicted_label) triples
    """
    # Inference mode: disable dropout / freeze batch-norm statistics
    model.eval()

    # Accumulates every wrongly-classified sample
    misclassified_data = []

    # Gradients are not needed while scanning the test set
    with torch.no_grad():
        for batch_images, batch_targets in test_loader:
            # Move the whole batch onto the requested device
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.to(device)

            # Examine each sample of the batch individually
            for sample, truth in zip(batch_images, batch_targets):
                # The model expects a leading batch dimension
                sample = sample.unsqueeze(0)

                # Forward pass, then pick the highest-scoring class
                prediction = model(sample).argmax(dim=1, keepdim=True)

                # Keep only the samples the model got wrong
                if prediction != truth:
                    misclassified_data.append((sample, truth, prediction))
    return misclassified_data
436
+
437
+
438
+ # -------------------- DATA STATISTICS --------------------
439
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Function to return the statistics of the training data
    :param data_set: Training dataset
    :param data_set_type: Type of dataset [Train/Test/Val]
    """
    # NOTE(review): relies on the `train_data` attribute — deprecated in
    # newer torchvision (exposed as `.data`); confirm the pinned version.
    raw_images = data_set.train_data
    # Re-apply the dataset transform on the raw numpy pixels
    raw_images = data_set.transform(raw_images.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', data_set.train_data.cpu().numpy().shape)
    print(' - Tensor Shape:', data_set.train_data.size())
    print(' - min:', torch.min(raw_images))
    print(' - max:', torch.max(raw_images))
    print(' - mean:', torch.mean(raw_images))
    print(' - std:', torch.std(raw_images))
    print(' - var:', torch.var(raw_images))

    # Pull one sample to report its shape and label
    first_item = next(iter(data_set))
    images, labels = first_item[0], first_item[1]

    print(images.shape)
    print(labels)

    # Let's visualize some of the images
    plt.imshow(images[0].numpy().squeeze(), cmap='gray')
466
+
467
+
468
def get_cifar_property(images, operation):
    """
    Get the property on each channel of the CIFAR
    :param images: (N, C, H, W) array/tensor to compute the property on
    :param operation: Name of the reduction method, e.g. 'mean', 'std', 'min', 'max', 'var'
    :return: Tuple of the reduction applied to the R, G and B channels
    """
    # getattr replaces the previous eval() on a built string: identical
    # behavior, but no arbitrary-code-execution surface.
    param_r = getattr(images[:, 0, :, :], operation)()
    param_g = getattr(images[:, 1, :, :], operation)()
    param_b = getattr(images[:, 2, :, :], operation)()
    return param_r, param_g, param_b
478
+
479
+
480
def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Function to get the statistical information of the CIFAR dataset
    :param data_set: Training set of CIFAR
    :param data_set_type: Training or Test data
    """
    # Stack every image into a single (N, C, H, W) numpy array
    stacked = torch.stack([sample[0] for sample in data_set], dim=0).numpy()

    # Per-channel (R, G, B) statistics, one tuple per reduction
    stats = {op: get_cifar_property(stacked, op)
             for op in ('mean', 'std', 'min', 'max', 'var')}

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {stacked[0].shape}')
    print(f' - min: {stats["min"]}')
    print(f' - max: {stats["max"]}')
    print(f' - mean: {stats["mean"]}')
    print(f' - std: {stats["std"]}')
    print(f' - var: {stats["var"]}')

    # Let's visualize some of the images
    plt.imshow(np.transpose(stacked[1].squeeze(), (1, 2, 0)))
516
+
517
+
518
+ # -------------------- GradCam --------------------
519
def display_gradcam_output(data: list,
                           classes: list[str],
                           inv_normalize: transforms.Normalize,
                           model: 'DL Model',
                           target_layers: list['model_layer'],
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List[Tuple(image, label, prediction)] e.g. from get_misclassified_data
    :param classes: Name of classes in the dataset
    :param inv_normalize: Mean and Standard deviation values of the dataset
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print
    :param transparency: Weight of Normal image when mixed with activations
    """
    # Plot configuration
    fig = plt.figure(figsize=(10, 10))
    x_count = 5
    # ceil (not floor): a partial last row still needs its own subplot row,
    # otherwise plt.subplot raises for counts that aren't multiples of 5
    y_count = 1 if number_of_samples <= 5 else math.ceil(number_of_samples / x_count)

    # Create an object for GradCam; use CUDA only when it is available so
    # the function also works on CPU-only machines
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())

    # Iterate over number of specified images
    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image: undo normalization, move channels last
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0))
        rgb_img = rgb_img.numpy()

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        plt.imshow(visualization)
        plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])