Navyabhat committed on
Commit
70b9a35
·
1 Parent(s): 33d956e

Upload 8 files

Browse files
models/custom_resnet.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv-BN-ReLU stages with an identity skip.

    Channel count and spatial size are preserved, so the input can be
    added directly to the branch output.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # 3x3 same-padding convolution; bias omitted because the
            # following BatchNorm supplies the affine shift.
            return nn.Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        self.resblock = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        # Identity shortcut around the convolutional branch.
        return self.resblock(x) + x
34
+
35
+
36
class CustomResnet(nn.Module):
    """Custom ResNet for 32x32 RGB inputs with 10 output classes.

    Pipeline: prep -> layer1 (+ResBlock) -> layer2 -> layer3 (+ResBlock)
    -> 4x4 max pool -> linear head. `forward` returns raw logits; the
    `softmax` attribute is kept for callers but intentionally not applied
    (cross-entropy losses expect logits).
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(in_c, out_c, pool=False):
            # Conv (-> optional 2x2 max pool) -> BN -> ReLU, as a layer list
            # so the Sequential child ordering matches the original layout.
            layers = [
                nn.Conv2d(
                    in_channels=in_c,
                    out_channels=out_c,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                )
            ]
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2))
            layers += [nn.BatchNorm2d(out_c), nn.ReLU()]
            return layers

        self.prep = nn.Sequential(*conv_bn_relu(3, 64))
        self.layer1 = nn.Sequential(
            *conv_bn_relu(64, 128, pool=True), ResBlock(channels=128)
        )
        self.layer2 = nn.Sequential(*conv_bn_relu(128, 256, pool=True))
        self.layer3 = nn.Sequential(
            *conv_bn_relu(256, 512, pool=True), ResBlock(channels=512)
        )

        self.pool = nn.MaxPool2d(kernel_size=4)  # 4x4 feature map -> 1x1
        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
        self.softmax = nn.Softmax(dim=-1)  # unused in forward; kept for callers

    def forward(self, x):
        out = self.prep(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = self.pool(out)
        out = out.view(-1, 512)
        return self.fc(out)
models/resnet_lightning.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import lightning as L
4
+ from torchmetrics import Accuracy
5
+ from typing import Any
6
+
7
+ from utils.common import one_cycle_lr
8
+
9
class ResidualBlock(L.LightningModule):
    """Residual block (two 3x3 conv-BN-ReLU stages) with an identity skip.

    Channel count and spatial size are preserved so the skip connection
    is a plain addition.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # bias omitted: the following BatchNorm supplies the shift.
            return nn.Conv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        self.residual_block = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.residual_block(x) + x
38
+
39
class ResNet(L.LightningModule):
    """LightningModule for the custom CIFAR-10 ResNet.

    Architecture: prep -> layer1 (+ResidualBlock) -> layer2 -> layer3
    (+ResidualBlock) -> 4x4 max pool -> linear head emitting 10 logits.

    Args:
        batch_size, shuffle, num_workers: kept as hyper-parameters for
            bookkeeping; actual data loading lives in the DataModule.
        learning_rate: Adam base learning rate.
        scheduler_steps: OneCycleLR steps per epoch; None disables the scheduler.
        maxlr: OneCycleLR peak LR; defaults to ``learning_rate``.
        epochs: total epochs for OneCycleLR; None disables the scheduler.
    """

    def __init__(
        self, batch_size=512, shuffle=True, num_workers=4, learning_rate=0.003, scheduler_steps=None, maxlr=None, epochs=None
    ):
        super(ResNet, self).__init__()
        self.data_dir = "./data"
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.learning_rate = learning_rate
        self.scheduler_steps = scheduler_steps
        self.maxlr = maxlr if maxlr is not None else learning_rate
        self.epochs = epochs

        # One shared loss instance with the default *mean* reduction so the
        # logged train/val/test losses are directly comparable. (Previously
        # validation/test used reduction="sum", which made the logged loss
        # scale with batch size.)
        self.criterion = nn.CrossEntropyLoss()

        self.prep = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(
                in_channels=64,
                out_channels=128,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResidualBlock(channels=128),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(
                in_channels=128,
                out_channels=256,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(
                in_channels=256,
                out_channels=512,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=False,
            ),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResidualBlock(channels=512),
        )

        self.pool = nn.MaxPool2d(kernel_size=4)

        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)

        # Not applied in forward (losses expect logits); kept for callers.
        self.softmax = nn.Softmax(dim=-1)

        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = x.view(-1, 512)
        x = self.fc(x)
        return x  # raw logits

    def configure_optimizers(self) -> Any:
        """Adam (weight decay 1e-4) plus an optional per-step OneCycleLR."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-4
        )
        # Without the scheduler hyper-parameters OneCycleLR cannot be built;
        # fall back to a constant learning rate instead of crashing.
        if self.scheduler_steps is None or self.epochs is None:
            return optimizer
        scheduler = one_cycle_lr(
            optimizer=optimizer, maxlr=self.maxlr, steps=self.scheduler_steps, epochs=self.epochs
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _shared_step(self, batch, prefix):
        """Compute mean loss + accuracy for one batch, log under `prefix`."""
        X, y = batch
        logits = self(X)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log_dict({f"{prefix}_loss": loss, f"{prefix}_acc": accuracy}, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")
utils/common.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ import matplotlib.pyplot as plt
4
+
5
+ import torch
6
+ import torchvision
7
+ from torchinfo import summary
8
+ from torch_lr_finder import LRFinder
9
+
10
+
11
def find_lr(model, optimizer, criterion, device, trainloader, numiter, startlr, endlr):
    """Run an exponential LR range test, plot it, and restore model state.

    Sweeps the learning rate from `startlr` to `endlr` over `numiter`
    iterations of `trainloader` and plots loss vs. learning rate.
    """
    finder = LRFinder(
        model=model, optimizer=optimizer, criterion=criterion, device=device
    )
    finder.range_test(
        train_loader=trainloader,
        start_lr=startlr,
        end_lr=endlr,
        num_iter=numiter,
        step_mode="exp",
    )
    finder.plot()
    # range_test mutates model/optimizer; undo that before real training.
    finder.reset()
27
+
28
+
29
def one_cycle_lr(optimizer, maxlr, steps, epochs):
    """Build a linear OneCycleLR schedule meant to be stepped once per batch.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        maxlr: peak learning rate (initial LR is maxlr / div_factor = maxlr / 100).
        steps: steps (batches) per epoch.
        epochs: total number of training epochs.

    Returns:
        A configured ``torch.optim.lr_scheduler.OneCycleLR``.
    """
    # OneCycleLR requires 0 < pct_start < 1. The 5/epochs warmup fraction is
    # only valid for epochs > 5 (it previously raised for shorter runs), so
    # clamp to a conventional 0.3 otherwise.
    pct_start = 5 / epochs if epochs > 5 else 0.3
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=maxlr,
        steps_per_epoch=steps,
        epochs=epochs,
        pct_start=pct_start,
        div_factor=100,
        three_phase=False,
        final_div_factor=100,
        anneal_strategy="linear",
    )
    return scheduler
42
+
43
+
44
def show_random_images_for_each_class(train_data, num_images_per_class=16):
    """Show a grid of `num_images_per_class` random samples for every class."""
    for class_idx, class_name in enumerate(train_data.classes):
        # Random sample indices belonging to this class.
        candidates = [i for i, t in enumerate(train_data.targets) if t == class_idx]
        chosen = random.sample(candidates, k=num_images_per_class)
        # HWC -> CHW so the batch is grid-compatible.
        show_img_grid(np.transpose(train_data.data[chosen], axes=(0, 3, 1, 2)))
        plt.title(class_name)
52
+
53
+
54
def show_img_grid(data):
    """Display a batch of images (torch tensor or numpy array) as one grid."""
    try:
        grid_img = torchvision.utils.make_grid(data.cpu().detach())
    except AttributeError:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit): numpy arrays have no
        # .cpu()/.detach(), so convert to a tensor first.
        data = torch.from_numpy(data)
        grid_img = torchvision.utils.make_grid(data)

    plt.figure(figsize=(10, 10))
    # CHW -> HWC for matplotlib.
    plt.imshow(grid_img.permute(1, 2, 0))
63
+
64
+
65
def show_random_images(data_loader):
    """Show the first batch yielded by `data_loader` as an image grid."""
    images, _labels = next(iter(data_loader))
    show_img_grid(images)
68
+
69
+
70
def show_model_summary(model, batch_size):
    """Print a torchinfo summary of `model` for CIFAR-sized input."""
    columns = ["input_size", "output_size", "num_params", "kernel_size"]
    # CIFAR-10 images: 3 channels, 32x32 pixels.
    summary(
        model=model,
        input_size=(batch_size, 3, 32, 32),
        col_names=columns,
        verbose=1,
    )
77
+
78
+
79
def lossacc_plots(results):
    """Plot train/validation loss and accuracy curves against epochs.

    Expects `results` to contain "epoch", "trainloss", "testloss",
    "trainacc" and "testacc" sequences of equal length.
    """
    for train_key, val_key, legend_name, axis_name in (
        ("trainloss", "testloss", "Loss", "Loss"),
        ("trainacc", "testacc", "Acc", "Accuracy"),
    ):
        plt.plot(results["epoch"], results[train_key])
        plt.plot(results["epoch"], results[val_key])
        plt.legend(["Train " + legend_name, "Validation " + legend_name])
        plt.xlabel("Epochs")
        plt.ylabel(axis_name)
        plt.title(axis_name + " vs Epochs")
        plt.show()
95
+
96
+
97
def lr_plots(results, length):
    """Plot the recorded learning-rate values over `length` steps."""
    steps = range(length)
    plt.plot(steps, results["lr"])
    plt.xlabel("Epochs")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate vs Epochs")
    plt.show()
103
+
104
+
105
def get_misclassified(model, testloader, device, mis_count=10):
    """Collect up to `mis_count` misclassified samples from `testloader`.

    Args:
        model: callable mapping a batch to class logits (assumed already in
            eval mode; this function does not change the model's mode).
        testloader: iterable of (data, target) batches.
        device: device the batches are moved to.
        mis_count: maximum number of samples to collect.

    Returns:
        (misimgs, mistgts, mispreds): image tensors, true labels and
        predicted labels, each a list of length <= mis_count.
    """
    misimgs, mistgts, mispreds = [], [], []
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            # reshape(-1) (not squeeze) keeps a 1-D index tensor even when
            # exactly one sample is misclassified; the previous squeeze()
            # produced a 0-d tensor there, which cannot be iterated.
            wrong = torch.argwhere(pred.reshape(-1) != target).reshape(-1)
            for idx in wrong:
                if len(misimgs) >= mis_count:
                    break
                misimgs.append(data[idx])
                mistgts.append(target[idx])
                mispreds.append(pred[idx].squeeze())
            # Stop iterating batches once enough samples were collected
            # (previously only the inner loop broke, so every remaining
            # batch was still run through the model).
            if len(misimgs) >= mis_count:
                break
    return misimgs, mistgts, mispreds
120
+
121
+
122
+ # def plot_misclassified(misimgs, mistgts, mispreds, classes):
123
+ # fig, axes = plt.subplots(len(misimgs) // 2, 2)
124
+ # fig.tight_layout()
125
+ # for ax, img, tgt, pred in zip(axes.ravel(), misimgs, mistgts, mispreds):
126
+ # ax.imshow((img / img.max()).permute(1, 2, 0).cpu())
127
+ # ax.set_title(f"{classes[tgt]} | {classes[pred]}")
128
+ # ax.grid(False)
129
+ # ax.set_axis_off()
130
+ # plt.show()
131
+
132
def get_misclassified_data(model, device, test_loader, count):
    """
    Run the model on the test set and return up to `count` misclassified images.

    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set
    :param count: maximum number of misclassified samples to return
    :return: list of (image, label, pred) tuples, length <= count
    """
    # Prepare the model for evaluation i.e. drop the dropout layer
    model.eval()

    # List to store misclassified Images
    misclassified_data = []

    with torch.no_grad():
        for data, target in test_loader:
            # Migrate the data to the device
            data, target = data.to(device), target.to(device)

            # Score one image at a time so each stored image keeps its own
            # batch dimension.
            for image, label in zip(data, target):
                image = image.unsqueeze(0)
                output = model(image)

                # Convert the logits to a predicted class index
                pred = output.argmax(dim=1, keepdim=True)

                if pred != label:
                    misclassified_data.append((image, label, pred))
                    # Early-exit both loops once enough samples were found.
                    # (Previously only the inner loop broke, so every
                    # remaining batch was still fully processed.)
                    if len(misclassified_data) >= count:
                        return misclassified_data

    return misclassified_data
172
+
173
def plot_misclassified(data, classes, size=(10, 10), rows=2, cols=5, inv_normalize=None):
    """Plot (image, label, pred) triples from `data` in a rows x cols grid."""
    plt.figure(figsize=size)
    for idx, (image, label, pred) in enumerate(data):
        plt.subplot(rows, cols, idx + 1)
        img = image.squeeze().to('cpu')
        if inv_normalize is not None:
            img = inv_normalize(img)
        # CHW -> HWC for matplotlib.
        plt.imshow(np.transpose(img, (1, 2, 0)))
        plt.title(f"Label: {classes[label.item()]} \n Prediction: {classes[pred.item()]}")
        plt.xticks([])
        plt.yticks([])
185
+
utils/config.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import toml
2
+ from pydantic import BaseModel
3
+
4
+ TOML_PATH = "config.toml"
5
+
6
+
7
class Data(BaseModel):
    """DataLoader settings, read from the [data] table of config.toml."""

    batch_size: int = 512
    shuffle: bool = True
    num_workers: int = 4
11
+
12
+
13
class LRFinder(BaseModel):
    """LR range-test settings ([training.lrfinder] table of config.toml)."""

    numiter: int = 600  # iterations of the range test
    endlr: float = 10  # upper bound of the LR sweep
    startlr: float = 1e-2  # lower bound of the LR sweep
17
+
18
+
19
class Training(BaseModel):
    """Training hyper-parameters ([training] table of config.toml)."""

    epochs: int = 20
    optimizer: str = "adam"
    criterion: str = "crossentropy"
    lr: float = 0.003
    weight_decay: float = 1e-4
    lrfinder: LRFinder  # nested LR range-test settings (no default: required)
26
+
27
+
28
class Config(BaseModel):
    """Top-level configuration schema mirroring config.toml."""

    data: Data
    training: Training
31
+
32
+
33
# Load and validate config.toml at import time; a missing file or an invalid
# schema raises here, so misconfiguration fails fast.
with open(TOML_PATH) as f:
    toml_config = toml.load(f)

config = Config(**toml_config)
utils/data.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import lightning as L
3
+ from torch.utils.data import DataLoader
4
+ from utils.transforms import train_transform, test_transform
5
+
6
+
7
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 wrapper that applies albumentations-style transforms.

    Albumentations transforms are called as ``transform(image=ndarray)`` and
    return a dict, unlike torchvision transforms, so __getitem__ is
    overridden accordingly.
    """

    def __init__(self, root="~/data", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        if self.transform is not None:
            # Albumentations API: keyword argument in, dict out.
            image = self.transform(image=image)["image"]
        return image, label
18
+
19
+
20
class CIFARDataModule(L.LightningDataModule):
    """LightningDataModule serving CIFAR-10 train/val/test loaders.

    Args:
        data_dir: dataset root directory.
        batch_size: samples per batch for every split.
        shuffle: whether to shuffle the *training* split (evaluation splits
            are never shuffled).
        num_workers: DataLoader worker processes per loader.
    """

    def __init__(
        self, data_dir="data", batch_size=512, shuffle=True, num_workers=4
    ) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        # Datasets download themselves in setup(); nothing extra to do here.
        pass

    def setup(self, stage=None):
        self.train_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=True, transform=train_transform
        )

        # Val and test intentionally share the CIFAR-10 test split.
        self.val_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

        self.test_dataset = Cifar10SearchDataset(
            root=self.data_dir, train=False, transform=test_transform
        )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        # Evaluation data must not be shuffled: metrics don't depend on
        # order and a deterministic order keeps per-sample debugging
        # reproducible. (Previously this honored self.shuffle.)
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
utils/gradcam.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pytorch_grad_cam import GradCAM
3
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
4
+ from pytorch_grad_cam.utils.image import show_cam_on_image
5
+
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
def generate_gradcam(model, target_layers, images, labels, rgb_imgs):
    """Return Grad-CAM overlay images for each (image, label, rgb_img) triple."""
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

    overlays = []
    for image, label, np_image in zip(images, labels, rgb_imgs):
        # Explain the model's score for the true class of this sample.
        targets = [ClassifierOutputTarget(label.item())]

        # aug_smooth=True averages the CAM over augmented copies for a
        # smoother map; only one image is passed, so take batch item 0.
        grayscale_cam = cam(
            input_tensor=image.unsqueeze(0), targets=targets, aug_smooth=True
        )[0, :]

        # Scale the image into [0, 1] before blending with the heatmap.
        overlays.append(
            show_cam_on_image(np_image / np_image.max(), grayscale_cam, use_rgb=True)
        )
    return overlays
28
+
29
+
30
def visualize_gradcam(misimgs, mistgts, mispreds, classes):
    """Show overlays in a two-column grid titled "target | prediction"."""
    fig, axes = plt.subplots(len(misimgs) // 2, 2)
    fig.tight_layout()
    for ax, (img, tgt, pred) in zip(axes.ravel(), zip(misimgs, mistgts, mispreds)):
        ax.imshow(img)
        ax.set_title(f"{classes[tgt]} | {classes[pred]}")
        ax.grid(False)
        ax.set_axis_off()
    plt.show()
39
+
40
def plot_gradcam(model, data, classes, target_layers, number_of_samples, inv_normalize=None, targets=None, transparency = 0.60, figsize=(10,10), rows=2, cols=5):
    """Plot Grad-CAM overlays for `number_of_samples` (image, label, pred) triples."""
    plt.figure(figsize=figsize)
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

    for idx in range(number_of_samples):
        plt.subplot(rows, cols, idx + 1)
        input_tensor = data[idx][0]

        # Activation map for this sample (first item of the returned batch).
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Undo normalization (if requested) and convert CHW -> HWC numpy.
        img = input_tensor.squeeze(0).to('cpu')
        if inv_normalize is not None:
            img = inv_normalize(img)
        rgb_img = np.transpose(img, (1, 2, 0)).numpy()

        # Blend the heatmap onto the image; image_weight sets the mix ratio.
        visualization = show_cam_on_image(
            rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency
        )
        plt.imshow(visualization)
        plt.title(f"Label: {classes[data[idx][1].item()]} \n Prediction: {classes[data[idx][2].item()]}")
        plt.xticks([])
        plt.yticks([])
utils/training.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
def train(
    model,
    device,
    train_loader,
    optimizer,
    criterion,
    scheduler,
    L1=False,
    l1_lambda=0.01,
):
    """Train `model` for one epoch over `train_loader`.

    Args:
        model: network to train (switched to train mode here).
        device: device batches are moved to.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss function applied to (logits, target).
        scheduler: LR scheduler stepped once per *batch* (e.g. OneCycleLR).
        L1: when True, add `l1_lambda * sum(|w|)` to the loss.
        l1_lambda: L1 regularization strength.

    Returns:
        (train_losses, train_acc, lrs): per-batch loss values, running
        accuracy (%) after each batch, and per-batch LR snapshots (each
        entry is the list returned by `scheduler.get_last_lr()`).
    """
    model.train()
    pbar = tqdm(train_loader)

    train_losses = []
    train_acc = []
    lrs = []

    correct = 0
    processed = 0

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)

        # Base loss plus optional L1 penalty over all parameters.
        # (Removed the previous dead `else: loss = loss` branch and the
        # unused train_loss accumulator.)
        loss = criterion(y_pred, target)
        if L1:
            l1_loss = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_loss

        train_losses.append(loss.item())

        # Backpropagation, then the per-batch LR schedule step.
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Running accuracy for the progress bar.
        pred = y_pred.argmax(
            dim=1, keepdim=True
        )  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)

        pbar.set_description(
            desc=f"Loss={loss.item():0.2f} Accuracy={100*correct/processed:0.2f}"
        )
        train_acc.append(100 * correct / processed)
        lrs.append(scheduler.get_last_lr())

    return train_losses, train_acc, lrs
64
+
65
+
66
def test(model, device, criterion, test_loader):
    """Evaluate `model` on `test_loader`; return (average loss, accuracy %).

    NOTE(review): the `criterion` argument is accepted but never used — the
    loss is always summed F.cross_entropy. Kept as-is to preserve the call
    signature and the reported values; confirm with callers before wiring
    `criterion` in.
    """
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses; divide by the dataset size afterwards.
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss = total_loss / num_samples
    test_acc = 100.0 * correct / num_samples

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
            test_loss,
            correct,
            num_samples,
            test_acc,
        )
    )

    return test_loss, test_acc
utils/transforms.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ from albumentations.pytorch import ToTensorV2
3
+
4
# Training-time augmentation pipeline (albumentations).
# Pad to 40x40 then random-crop back to 32x32 (random translation), random
# horizontal flip, and a single 8x8 cutout filled with the CIFAR-10 channel
# means scaled to 0-255. Finally normalize with CIFAR-10 statistics and
# convert the HWC array to a CHW tensor.
train_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=40, min_width=40, always_apply=True),
        A.RandomCrop(height=32, width=32, always_apply=True),
        A.HorizontalFlip(),
        A.CoarseDropout(
            min_holes=1,
            max_holes=1,
            min_height=8,
            min_width=8,
            max_height=8,
            max_width=8,
            fill_value=[0.49139968*255, 0.48215827*255 ,0.44653124*255],  # type: ignore
            p=0.5,
        ),
        A.Normalize((0.49139968, 0.48215827, 0.44653124),
                    (0.24703233, 0.24348505, 0.26158768)),
        ToTensorV2(),
    ]
)
24
+
25
# Evaluation pipeline: no augmentation, only CIFAR-10 normalization and
# conversion to a CHW tensor.
test_transform = A.Compose(
    [
        A.Normalize((0.49139968, 0.48215827, 0.44653124),
                    (0.24703233, 0.24348505, 0.26158768)),
        ToTensorV2(),
    ]
)