Vvaann commited on
Commit
9b0799a
·
verified ·
1 Parent(s): 696f4a2

Upload 6 files

Browse files
Files changed (6) hide show
  1. augmentations.py +56 -0
  2. datasets.py +38 -0
  3. resnet.py +76 -0
  4. training_utils.py +126 -0
  5. utils.py +201 -0
  6. visualize.py +384 -0
augmentations.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Function used for visualization of data and results
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jul 23, 2023
6
+ """
7
+ # Third-Party Imports
8
+ import torch
9
+ import albumentations as A
10
+ from albumentations.pytorch import ToTensorV2
11
+
12
+
13
# Train Phase transformations
# NOTE(review): values are individual albumentations transforms; callers appear
# to compose them (see visualize_cifar_augmentation, which iterates the dict) —
# confirm the training pipeline wraps them in A.Compose.
train_set_transforms = {
    # RandomCrop at the full 32x32 size, applied with 20% probability
    'randomcrop': A.RandomCrop(height=32, width=32, p=0.2),
    'horizontalflip': A.HorizontalFlip(),
    # Cutout-style occlusion: one hole up to 16x16, filled with the CIFAR-10
    # channel means scaled back to the 0-255 pixel range
    'cutout': A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=1, min_width=1, fill_value=[0.49139968*255, 0.48215827*255 ,0.44653124*255], mask_fill_value=None),
    # CIFAR-10 per-channel mean and std
    'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
    # HWC numpy image -> CHW torch tensor
    'standardize': ToTensorV2(),
}

# Test Phase transformations (no augmentation — only normalization + tensor conversion)
test_set_transforms = {
    'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
    'standardize': ToTensorV2()
}
27
+
28
+
29
class AddGaussianNoise(object):
    """
    Custom augmentation transform that adds zero-centered Gaussian noise
    (scaled by `std`, shifted by `mean`) to a tensor.
    """
    def __init__(self, mean=0., std=1.):
        """
        Constructor
        :param mean: Offset added to the noise
        :param std: Scale applied to the unit-normal noise
        """
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Apply the augmentation: sample unit-normal noise of the same shape,
        scale/shift it, and add it to the input tensor.
        """
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        """
        Readable representation showing the configured mean and std
        """
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
51
+
52
+ # Usage details
53
+ # transforms = transforms.Compose([
54
+ # transforms.ToTensor(),
55
+ # AddGaussianNoise(0., 1.0),
56
+ # ])
datasets.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module containing wrapper classes for PyTorch Datasets
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jun 25, 2023
6
+ """
7
+ # Standard Library Imports
8
+ from typing import Tuple
9
+
10
+ # Third-Party Imports
11
+ from torchvision import datasets, transforms
12
+
13
+
14
class AlbumDataset(datasets.CIFAR10):
    """
    CIFAR-10 wrapper that applies albumentations-style transforms, which take
    the image as a keyword argument and return a dict with an 'image' key.
    """
    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
        """
        Constructor
        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: List of transformation to be performed on the dataset
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Return the (possibly transformed) image and its label.
        :param index: Index of image and label in the dataset
        """
        image = self.data[index]
        label = self.targets[index]

        # albumentations pipelines are called with image= and return a dict
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, label
resnet.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ResNet in PyTorch.
3
+ For Pre-activation ResNet, see 'preact_resnet.py'.
4
+
5
+ Reference:
6
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
8
+ """
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block (post-activation ResNet v1)."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        :param in_planes: Number of input channels
        :param planes: Number of output channels of each conv
        :param stride: Stride of the first conv (downsamples when > 1)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 conv projection matches the shapes.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Residual forward pass: relu(F(x) + shortcut(x))."""
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity
        return F.relu(out)
36
+
37
+
38
class ResNet(nn.Module):
    """
    CIFAR-style ResNet: 3x3 conv stem, four residual stages with
    64/128/256/512 planes, 4x4 average pool, then a linear classifier.
    """
    def __init__(self, block, num_blocks, num_classes=10):
        """
        :param block: Residual block class (e.g. BasicBlock)
        :param num_blocks: Blocks per stage, e.g. [2, 2, 2, 2]
        :param num_classes: Size of the classifier output
        """
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # (planes, stride of first block) per stage; only stage 1 keeps stride 1
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: first block uses `stride`, the rest use stride 1."""
        stage = []
        for s in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Forward pass returning raw class logits."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = F.avg_pool2d(out, 4)            # 4x4 -> 1x1 for 32x32 inputs
        out = out.view(out.size(0), -1)
        return self.linear(out)
69
+
70
+
71
def ResNet18():
    """Factory for an 18-layer ResNet: four stages of two BasicBlocks each."""
    return ResNet(BasicBlock, num_blocks=[2, 2, 2, 2])
73
+
74
+
75
def ResNet34():
    """Factory for a 34-layer ResNet: stages of 3/4/6/3 BasicBlocks."""
    return ResNet(BasicBlock, num_blocks=[3, 4, 6, 3])
training_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utilities for Model Training
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jun 21, 2023
6
+ """
7
+ # Standard Library Imports
8
+
9
+ # Third-Party Imports
10
+ from tqdm import tqdm
11
+ import torch
12
+
13
+
14
def get_correct_predictions(prediction, labels):
    """
    Count how many predictions match the ground-truth labels.
    :param prediction: Model output scores, shape (batch, num_classes)
    :param labels: Ground-truth class indices, shape (batch,)
    :return: Number of correct predictions as a Python int
    """
    predicted_classes = prediction.argmax(dim=1)
    return predicted_classes.eq(labels).sum().item()
22
+
23
+
24
def train(model, device, train_loader, optimizer, criterion, scheduler=None):
    """
    Function to train model on the training dataset
    :param model: Model architecture
    :param device: Device on which training is to be done (GPU/CPU)
    :param train_loader: DataLoader for training dataset
    :param optimizer: Optimization algorithm to be used for updating weights
    :param criterion: Loss function for training
    :param scheduler: Scheduler for learning rate; stepped once PER BATCH, which
                      suits batch-level schedulers (e.g. OneCycleLR) — do not
                      pass a per-epoch scheduler here
    :return: Tuple of (correct prediction count, samples processed,
             sum of per-batch mean losses for the epoch)
    """
    # Enable layers like Dropout for model training
    model.train()

    # Utility to display training progress
    pbar = tqdm(train_loader)

    # Variables to track loss and accuracy during training
    train_loss = 0
    correct = 0
    processed = 0

    # Iterate over each batch and fetch images and labels from the batch
    for batch_idx, (data, target) in enumerate(pbar):

        # Put the images and labels on the selected device
        data, target = data.to(device), target.to(device)

        # Reset the gradients for each batch
        optimizer.zero_grad()

        # Predict
        pred = model(data)

        # Calculate loss (accumulated as a running sum of batch means)
        loss = criterion(pred, target)
        train_loss += loss.item()

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Use learning rate scheduler if defined
        if scheduler:
            scheduler.step()

        # Get total number of correct predictions
        correct += get_correct_predictions(pred, target)
        processed += len(data)

        # Display the training information
        pbar.set_description(
            desc=f'Train: Loss={loss.item():0.4f} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')

    return correct, processed, train_loss
78
+
79
+
80
def test(model, device, test_loader, criterion):
    """
    Evaluate the model on the test dataset and print a summary line.
    :param model: Model architecture
    :param device: Device on which evaluation is to be done (GPU/CPU)
    :param test_loader: DataLoader for test dataset
    :param criterion: Loss function for test dataset
    :return: Tuple of (correct prediction count, loss averaged over the dataset)
    """
    # Disable layers like Dropout for model inference
    model.eval()

    # Running totals over the whole test set
    test_loss = 0
    correct = 0

    # No gradients are needed during evaluation
    with torch.no_grad():
        for data, target in test_loader:
            # Move the batch to the selected device
            data, target = data.to(device), target.to(device)

            # Forward pass and loss accumulation
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss

            # Accumulate correct predictions
            correct += get_correct_predictions(output, target)

    # Average the accumulated loss over the dataset size
    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return correct, test_loss
118
+
119
+
120
def get_lr(optimizer):
    """
    Fetch the current learning rate from the optimizer.
    :param optimizer: Optimizer used for training
    :return: Learning rate of the first parameter group (None if there are none)
    """
    groups = optimizer.param_groups
    return groups[0]['lr'] if groups else None
utils.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility Script containing functions to be used for training
4
+ Author: Shilpaj Bhalerao
5
+ """
6
+ # Standard Library Imports
7
+ import math
8
+ from typing import NoReturn
9
+
10
+ # Third-Party Imports
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import torch
14
+ from torchsummary import summary
15
+ from torchvision import transforms
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+
19
+
20
def get_summary(model: 'object of model architecture', input_size: tuple) -> NoReturn:
    """
    Print a layer-by-layer summary of the model via torchsummary.
    :param model: Object of model architecture class
    :param input_size: Input data shape (Channels, Height, Width)
    """
    # Prefer GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = model.to(device)
    summary(network, input_size=input_size)
30
+
31
+
32
def get_misclassified_data(model, device, test_loader):
    """
    Run the model over the test set and collect the samples it gets wrong.
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set
    :return: List of (image, label, prediction) tuples; image keeps a leading
             batch dimension, prediction has shape (1, 1)
    """
    # Inference mode: disable Dropout etc.
    model.eval()

    # Collected (image, label, prediction) tuples
    misclassified_data = []

    # No gradients needed for inference
    with torch.no_grad():
        for data, target in test_loader:
            # Move the batch to the selected device
            data, target = data.to(device), target.to(device)

            # Score each sample individually so the stored image keeps its batch dim
            for image, label in zip(data, target):
                image = image.unsqueeze(0)
                output = model(image)

                # Class index with the highest score
                pred = output.argmax(dim=1, keepdim=True)

                if pred != label:
                    misclassified_data.append((image, label, pred))
    return misclassified_data
69
+
70
+
71
+ # -------------------- DATA STATISTICS --------------------
72
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Function to return the statistics of the training data
    :param data_set: Training dataset (torchvision MNIST-style object exposing
                     `train_data` and `transform`)
    :param data_set_type: Type of dataset [Train/Test/Val]
    """
    # We'd need to convert it into Numpy! Remember above we have converted it into tensors already
    # NOTE(review): `train_data` is the legacy torchvision attribute (newer
    # releases expose `.data`) — confirm against the installed version.
    train_data = data_set.train_data
    # Runs the dataset's transform pipeline over the WHOLE array at once, so
    # the stats below describe the transformed data
    train_data = data_set.transform(train_data.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', data_set.train_data.cpu().numpy().shape)
    print(' - Tensor Shape:', data_set.train_data.size())
    print(' - min:', torch.min(train_data))
    print(' - max:', torch.max(train_data))
    print(' - mean:', torch.mean(train_data))
    print(' - std:', torch.std(train_data))
    print(' - var:', torch.var(train_data))

    # Peek at the first sample (after per-sample transforms)
    dataiter = next(iter(data_set))
    images, labels = dataiter[0], dataiter[1]

    print(images.shape)
    print(labels)

    # Let's visualize some of the images
    plt.imshow(images[0].numpy().squeeze(), cmap='gray')
99
+
100
+
101
def get_cifar_property(images, operation):
    """
    Compute a statistic on each channel of a batch of CIFAR images.
    :param images: Images shaped (N, C, H, W) (numpy array or torch tensor)
    :param operation: Name of the reduction method to call on each channel
                      slice, e.g. 'mean', 'std', 'min', 'max', 'var'
    :return: Tuple with the statistic for the R, G and B channels
    :raises AttributeError: If the channel slice has no method `operation`
    """
    # getattr gives the same dynamic dispatch the previous eval-based code had,
    # without executing arbitrary strings.
    param_r = getattr(images[:, 0, :, :], operation)()
    param_g = getattr(images[:, 1, :, :], operation)()
    param_b = getattr(images[:, 2, :, :], operation)()
    return param_r, param_g, param_b
111
+
112
+
113
def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Print statistical information (min/max/mean/std/var per channel) of a
    CIFAR dataset and show a sample image.
    :param data_set: Training set of CIFAR
    :param data_set_type: Training or Test data
    """
    # Stack every image tensor into one (N, C, H, W) numpy array
    # (materializes the whole dataset in memory)
    images = torch.stack([sample[0] for sample in data_set], dim=0).numpy()

    # Per-channel statistics, each an (R, G, B) tuple
    stats = {op: get_cifar_property(images, op)
             for op in ('mean', 'std', 'min', 'max', 'var')}

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {images[0].shape}')
    print(f' - min: {stats["min"]}')
    print(f' - max: {stats["max"]}')
    print(f' - mean: {stats["mean"]}')
    print(f' - std: {stats["std"]}')
    print(f' - var: {stats["var"]}')

    # Show one sample image (CHW -> HWC for matplotlib)
    plt.imshow(np.transpose(images[1].squeeze(), (1, 2, 0)))
149
+
150
+
151
+ # -------------------- GradCam --------------------
152
def display_gradcam_output(data: list,
                           classes: list[str],
                           inv_normalize: transforms.Normalize,
                           model: 'DL Model',
                           target_layers: list['model_layer'],
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List[Tuple(image, label, prediction)] — matches the tuple
                 layout produced by get_misclassified_data (image with batch
                 dim, tensor label, tensor prediction)
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform that undoes the dataset normalization
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam (forwarded to the
                    pytorch-grad-cam library; see its docs for None semantics)
    :param number_of_samples: Number of images to print
    :param transparency: Weight of Normal image when mixed with activations
    """
    # Plot configuration: fixed 5-column grid
    fig = plt.figure(figsize=(10, 10))
    x_count = 5
    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    # Iterate over number of specified images
    for i in range(number_of_samples):
        plt.subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the images
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image: drop the batch dim, undo normalization,
        # and rearrange CHW -> HWC for display
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0))
        rgb_img = rgb_img.numpy()

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        plt.imshow(visualization)
        plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])
visualize.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Function used for visualization of data and results
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jun 21, 2023
6
+ """
7
+ # Standard Library Imports
8
+ import math
9
+ from dataclasses import dataclass
10
+ from typing import NoReturn
11
+
12
+ # Third-Party Imports
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import pandas as pd
16
+ import seaborn as sn
17
+ import torch
18
+ import torch.nn as nn
19
+ from torchvision import transforms
20
+ from sklearn.metrics import confusion_matrix
21
+
22
+
23
+ # ---------------------------- DATA SAMPLES ----------------------------
24
def display_mnist_data_samples(dataset: 'DataLoader object', number_of_samples: int) -> NoReturn:
    """
    Display sample images with their labels from an MNIST-style dataset
    :param dataset: Train or Test dataset transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    """
    # Collect the leading samples from the dataset
    batch_data = []
    batch_label = []
    for count, item in enumerate(dataset):
        if count > number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])

    # Grid layout: 5 columns, enough rows for the requested samples
    fig = plt.figure()
    x_count = 5
    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

    for index in range(number_of_samples):
        plt.subplot(y_count, x_count, index + 1)
        plt.tight_layout()
        plt.imshow(batch_data[index].squeeze(), cmap='gray')
        plt.title(batch_label[index])
        plt.xticks([])
        plt.yticks([])
+
53
+
54
def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
    """
    Display sample images with class names from a CIFAR-style dataset
    :param data_set: Train or Test data_set transformed to Tensor
    :param number_of_samples: Number of samples to be displayed
    :param classes: Name of classes to be displayed
    """
    # Gather the leading samples and stack them into one (N, C, H, W) array
    batch_data = []
    batch_label = []
    for count, item in enumerate(data_set):
        if count > number_of_samples:
            break
        batch_data.append(item[0])
        batch_label.append(item[1])
    batch_data = torch.stack(batch_data, dim=0).numpy()

    # Grid layout: 5 columns
    fig = plt.figure()
    x_count = 5
    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

    for index in range(number_of_samples):
        plt.subplot(y_count, x_count, index + 1)
        plt.tight_layout()
        # CHW -> HWC for matplotlib
        plt.imshow(np.transpose(batch_data[index].squeeze(), (1, 2, 0)))
        plt.title(classes[batch_label[index]])
        plt.xticks([])
        plt.yticks([])
83
+
84
+
85
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
86
def display_cifar_misclassified_data(data: list,
                                     classes: list[str],
                                     inv_normalize: transforms.Normalize,
                                     number_of_samples: int = 10):
    """
    Plot misclassified CIFAR images with correct and predicted class names
    :param data: List[Tuple(image, label, prediction)]
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform that undoes the dataset normalization
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(10, 10))

    # Grid layout: 5 columns
    x_count = 5
    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

    for index in range(number_of_samples):
        plt.subplot(y_count, x_count, index + 1)
        # Undo normalization so colors render correctly, CHW -> HWC
        image = inv_normalize(data[index][0].squeeze().to('cpu'))
        plt.imshow(np.transpose(image, (1, 2, 0)))
        plt.title(r"Correct: " + classes[data[index][1].item()] + '\n' + 'Output: ' + classes[data[index][2].item()])
        plt.xticks([])
        plt.yticks([])
110
+
111
+
112
def display_mnist_misclassified_data(data: list,
                                     number_of_samples: int = 10):
    """
    Plot misclassified MNIST images with correct and predicted digits
    :param data: List[Tuple(image, label, prediction)]
    :param number_of_samples: Number of images to print
    """
    fig = plt.figure(figsize=(8, 5))

    # Grid layout: 5 columns
    x_count = 5
    y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)

    for index in range(number_of_samples):
        plt.subplot(y_count, x_count, index + 1)
        # Drop the batch dimension; CHW -> HWC for matplotlib
        image = data[index][0].squeeze(0).to('cpu')
        plt.imshow(np.transpose(image, (1, 2, 0)), cmap='gray')
        plt.title(r"Correct: " + str(data[index][1].item()) + '\n' + 'Output: ' + str(data[index][2].item()))
        plt.xticks([])
        plt.yticks([])
131
+
132
+
133
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
134
def visualize_cifar_augmentation(data_set, data_transforms):
    """
    Show one CIFAR sample under each configured augmentation
    :param data_set: Dataset without transformations
    :param data_transforms: Dictionary of albumentations transforms keyed by name
    """
    # Fixed sample from the dataset to make comparisons easy
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (name, transform) in enumerate(data_transforms.items()):
        # Skip the final transform in the dictionary (tensor conversion)
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        plt.imshow(transform(image=sample)['image'])
        plt.title(name)
        plt.xticks([])
        plt.yticks([])
153
+
154
+
155
def visualize_mnist_augmentation(data_set, data_transforms):
    """
    Show one MNIST sample under each configured augmentation
    :param data_set: Dataset to visualize the augmentations
    :param data_transforms: Dictionary of torchvision-style transforms keyed by name
    """
    # Fixed sample from the dataset to make comparisons easy
    sample, label = data_set[6]
    total_augmentations = len(data_transforms)

    fig = plt.figure(figsize=(10, 5))
    for count, (name, transform) in enumerate(data_transforms.items()):
        # Skip the final transform in the dictionary
        if count == total_augmentations - 1:
            break
        plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
        augmented = transform(sample).to('cpu')
        # CHW -> HWC for matplotlib
        plt.imshow(np.transpose(augmented, (1, 2, 0)), cmap='gray')
        plt.title(name)
        plt.xticks([])
        plt.yticks([])
174
+
175
+
176
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
177
def display_loss_and_accuracies(train_losses: list,
                                train_acc: list,
                                test_losses: list,
                                test_acc: list,
                                plot_size: tuple = (10, 10)) -> NoReturn:
    """
    Plot training and test losses/accuracies on a 2x2 grid
    :param train_losses: List containing training loss of each epoch
    :param train_acc: List containing training accuracy of each epoch
    :param test_losses: List containing test loss of each epoch
    :param test_acc: List containing test accuracy of each epoch
    :param plot_size: Size of the plot
    """
    fig, axs = plt.subplots(2, 2, figsize=plot_size)

    # (row, col, series, title): column 0 = train, column 1 = test;
    # row 0 = loss, row 1 = accuracy
    panels = (
        (0, 0, train_losses, "Training Loss"),
        (1, 0, train_acc, "Training Accuracy"),
        (0, 1, test_losses, "Test Loss"),
        (1, 1, test_acc, "Test Accuracy"),
    )
    for row, col, series, title in panels:
        axs[row, col].plot(series)
        axs[row, col].set_title(title)
204
+
205
+
206
+ # ---------------------------- Feature Maps and Kernels ----------------------------
207
+
208
@dataclass
class ConvLayerInfo:
    """
    Data Class to store Conv layer's information
    """
    # 1-based position of the conv layer in discovery order
    layer_number: int
    # Learnable weight tensor of the layer (shared reference, not a copy)
    weights: torch.nn.parameter.Parameter
    # The Conv2d module itself
    layer_info: torch.nn.modules.conv.Conv2d
216
+
217
+
218
class FeatureMapVisualizer:
    """
    Class to visualize Feature Map of the Layers

    Only Conv2d modules found inside top-level nn.Sequential children of the
    model are collected (see the constructor loop); conv layers attached
    directly to the model are not picked up.
    """

    def __init__(self, model):
        """
        Contructor
        :param model: Model Architecture
        """
        # ConvLayerInfo records for every discovered Conv2d
        self.conv_layers = []
        # Cache of {layer repr: feature map} from the last forward scan
        self.outputs = []
        # Per-layer out_channel counts, filled by get_max_kernel_number()
        self.layerwise_kernels = None

        # Disect the model
        counter = 0
        model_children = model.children()
        for children in model_children:
            if type(children) == nn.Sequential:
                for child in children:
                    if type(child) == nn.Conv2d:
                        counter += 1
                        self.conv_layers.append(ConvLayerInfo(layer_number=counter,
                                                              weights=child.weight,
                                                              layer_info=child)
                                                )

    def get_model_weights(self):
        """
        Method to get the model weights (one weight tensor per discovered conv layer)
        """
        model_weights = [layer.weights for layer in self.conv_layers]
        return model_weights

    def get_conv_layers(self):
        """
        Get the convolution layers (the Conv2d modules themselves)
        """
        conv_layers = [layer.layer_info for layer in self.conv_layers]
        return conv_layers

    def get_total_conv_layers(self) -> int:
        """
        Get total number of convolution layers
        """
        out = self.get_conv_layers()
        return len(out)

    def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
        """
        Get feature maps from all the kernels of all the layers.
        The image is passed through ONLY the conv layers, chained in discovery
        order (batch-norm/activation layers are skipped), so these maps are an
        approximation of the real intermediate activations.
        :param image: Image to be passed to the network (C, H, W; batch dim is added here)
        """
        image = image.unsqueeze(0)
        # Runs on CPU — model weights are assumed to be on CPU as well
        image = image.to('cpu')

        outputs = {}

        layers = self.get_conv_layers()
        for index, layer in enumerate(layers):
            image = layer(image)
            # Keyed by the layer's repr string
            outputs[str(layer)] = image
        self.outputs = outputs
        return outputs

    def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int) -> None:
        """
        Function to visualize feature map of kernel number from each layer
        :param image: Image passed to the network
        :param kernel_number: Number of kernel in each layer (Should be less than or equal to the minimum number of kernel in the network)
        """
        # List to store processed feature maps
        processed = []

        # Get feature maps from all kernels of all the conv layers
        outputs = self.feature_maps_of_all_kernels(image)

        # Extract the n_th kernel's output from each layer and convert it to grayscale
        for feature_map in outputs.values():
            try:
                feature_map = feature_map[0][kernel_number]
            except IndexError:
                print("Filter number should be less than the minimum number of channels in a network")
                break
            # NOTE(review): `finally` also runs on the IndexError path above,
            # so the un-indexed feature_map is scaled and appended before the
            # break takes effect — confirm this is intended.
            finally:
                # Scale by the leading dimension as a rough grayscale conversion
                gray_scale = feature_map / feature_map.shape[0]
                processed.append(gray_scale.data.numpy())

        # Plot the Feature maps with layer and kernel number
        x_range = len(outputs) // 5 + 4
        fig = plt.figure(figsize=(10, 10))
        for i in range(len(processed)):
            a = fig.add_subplot(x_range, 5, i + 1)
            imgplot = plt.imshow(processed[i])
            a.axis("off")
            title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
            a.set_title(title, fontsize=10)

    def get_max_kernel_number(self):
        """
        Function to get maximum number of kernels in the network (for a layer).
        Also caches the per-layer kernel counts in self.layerwise_kernels.
        """
        layers = self.get_conv_layers()
        channels = [layer.out_channels for layer in layers]
        self.layerwise_kernels = channels
        return max(channels)

    def visualize_kernels_from_layer(self, layer_number: int):
        """
        Visualize Kernels from a layer (only the first input channel of each filter)
        :param layer_number: Number of layer from which kernels are to be visualized (1-based)
        """
        # Get the kernels number for each layer
        self.get_max_kernel_number()

        # Zero Indexing
        layer_number = layer_number - 1
        _kernels = self.layerwise_kernels[layer_number]

        # Square-ish grid large enough for all kernels
        grid = math.ceil(math.sqrt(_kernels))

        plt.figure(figsize=(5, 4))
        model_weights = self.get_model_weights()
        _layer_weights = model_weights[layer_number].cpu()
        for i, filter in enumerate(_layer_weights):
            plt.subplot(grid, grid, i + 1)
            plt.imshow(filter[0, :, :].detach(), cmap='gray')
            plt.axis('off')
        plt.show()
347
+
348
+
349
+ # ---------------------------- Confusion Matrix ----------------------------
350
def visualize_confusion_matrix(classes: list[str], device: str, model: 'DL Model',
                               test_loader: torch.utils.data.DataLoader):
    """
    Generate and plot a row-normalized confusion matrix over the WHOLE test set
    :param classes: List of class names
    :param device: cuda/cpu (now honored — previously hard-coded to 'cuda')
    :param model: Model Architecture
    :param test_loader: DataLoader for test set
    """
    nb_classes = len(classes)
    # Accumulator: rows = true class, cols = predicted class
    cm = torch.zeros(nb_classes, nb_classes)

    # BUGFIX: the `device` parameter used to be overwritten with 'cuda',
    # which broke CPU-only runs; move the model once, outside the loop.
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs).argmax(dim=1)

            # Tally every (true, predicted) pair of this batch
            for t, p in zip(labels.view(-1), preds.view(-1)):
                cm[t, p] = cm[t, p] + 1

    # BUGFIX: normalize and plot the matrix accumulated over ALL batches;
    # previously only the final batch was re-counted via sklearn and plotted.
    cf_matrix = cm.numpy()
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=[i for i in classes],
                         columns=[i for i in classes])
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)