# Cifar10 / resnet.py
# PrarthanaTS's picture
# Update resnet.py
# c79dd40
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 12 13:50:39 2023
@author: prarthana.ts
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from io import BytesIO
import numpy as np
import os
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch_lr_finder import LRFinder
import math
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from PIL import Image
import torch
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from albumentations import *
from albumentations.pytorch.transforms import ToTensorV2
import cv2
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
import matplotlib.gridspec as gridspec
import torchmetrics
import pandas as pd
import numpy as np
import seaborn as sns
from helper import *
from model import *
class CustomResNet(pl.LightningModule):
    """Residual CNN classifier for CIFAR-10 packaged as a LightningModule.

    Besides the network itself, the module owns its data pipeline
    (download, dataset construction, dataloaders), the optimiser with a
    one-cycle learning-rate schedule, a test-time confusion matrix, and
    helpers that visualise misclassified test images with optional
    Grad-CAM overlays.

    Args:
        config: Mapping read for the keys ``'classes'`` (class names used
            for plot labels), ``'num_classes'``, ``'batch_size'`` and
            ``'data_dir'``.
        dropout: Forwarded unchanged to ``PrepBlock``.
        train_transforms: Transform pipeline handed to
            ``CifarAlbumentations`` for the training split.
        test_transforms: Transform pipeline handed to
            ``CifarAlbumentations`` for the validation and test splits.
    """

    # Per-channel statistics used to undo input normalisation before an
    # image is displayed. NOTE(review): assumed to match the Normalize
    # step inside the supplied transforms -- confirm against the caller.
    _DATASET_MEAN = np.array([0.49139968, 0.48215841, 0.44653091])
    _DATASET_STD = np.array([0.24703223, 0.24348513, 0.26158784])

    def __init__(self, config, dropout, train_transforms, test_transforms):
        super().__init__()
        self.config = config
        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.classes = config['classes']
        num_classes = config['num_classes']

        self.prep = PrepBlock(dropout)
        self.conv1 = ConvolutionBlock(64, 128)
        self.R1 = ResidualBlock(128)
        self.conv2 = ConvolutionBlock(128, 256)
        self.conv3 = ConvolutionBlock(256, 512)
        self.R2 = ResidualBlock(512)
        self.maxpool = nn.MaxPool2d(kernel_size=(4, 4))
        # The head width follows the configured class count so it always
        # agrees with the metrics below.
        self.linear = nn.Linear(512, num_classes)
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task="multiclass", num_classes=num_classes
        )

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        The pooled feature map is flattened per sample and must yield
        512 features for the linear head.
        """
        x = self.prep(x)
        x = self.conv1(x)
        x = self.R1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.R2(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return F.log_softmax(x, dim=1)

    def _run_batch(self, batch):
        """Run one batch and return ``(log_probs, targets, loss, pred, acc)``."""
        x, y = batch
        log_probs = self(x)
        # cross_entropy applies log_softmax once more; on outputs that are
        # already log-probabilities that is mathematically a no-op, so
        # this is the negative log-likelihood of the targets.
        loss = F.cross_entropy(log_probs, y)
        pred = log_probs.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        return log_probs, y, loss, pred, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss/accuracy per step and epoch."""
        _, _, loss, _, acc = self._run_batch(batch)
        self.log('train_losses', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        _, _, loss, _, acc = self._run_batch(batch)
        self.log('validation_losses', loss, prog_bar=True)
        self.log('validation_accuracy', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy, feed the confusion matrix, return predictions."""
        log_probs, y, loss, pred, acc = self._run_batch(batch)
        self.confusion_matrix.update(log_probs, y)
        self.log('test_losses', loss, prog_bar=True)
        self.log('test_accuracy', acc, prog_bar=True)
        return pred

    def configure_optimizers(self):
        """Build Adam plus a per-step OneCycleLR schedule.

        The warm-up fraction is ``5 / max_epochs``, i.e. the learning rate
        ramps up over the first five epochs. Requires ``setup`` to have
        created ``self.train_dataset`` beforehand.
        """
        # Optimise this module's own parameters, not a module-level global.
        optimizer = Adam(self.parameters(), lr=0.01)
        max_epochs = self.trainer.max_epochs
        # Take the step count from the real training loader so the cycle
        # spans exactly the planned run instead of a hard-coded size.
        steps_per_epoch = len(self.train_dataloader())
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                max_lr=0.0265608,
                epochs=max_epochs,
                steps_per_epoch=steps_per_epoch,
                pct_start=5 / max_epochs,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``config['data_dir']``."""
        CIFAR10(self.config['data_dir'], train=True, download=True)
        CIFAR10(self.config['data_dir'], train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        The validation and test datasets both wrap the ``train=False``
        split. The test dataset is built for every stage because
        ``collect_misclassified_images`` reads it after a plain fit.
        """
        data_dir = self.config['data_dir']
        if stage in (None, 'fit'):
            self.train_dataset = CifarAlbumentations(
                CIFAR10(data_dir, train=True, download=True),
                transforms=self.train_transforms,
            )
        if stage in (None, 'fit', 'validate'):
            self.val_dataset = CifarAlbumentations(
                CIFAR10(data_dir, train=False, download=True),
                transforms=self.test_transforms,
            )
        self.test_dataset = CifarAlbumentations(
            CIFAR10(data_dir, train=False, download=True),
            transforms=self.test_transforms,
        )

    def _make_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(
            dataset,
            batch_size=self.config['batch_size'],
            shuffle=shuffle,
            # os.cpu_count() may return None; fall back to in-process loading.
            num_workers=os.cpu_count() or 0,
        )

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self.test_dataset)

    def on_test_end(self) -> None:
        """Plot the row-normalised confusion matrix accumulated during testing."""
        cm = self.confusion_matrix.cpu().compute().numpy()
        row_totals = cm.sum(axis=1, keepdims=True)
        # Each row is divided by its number of true samples; a class with
        # no samples keeps a zero row instead of becoming NaN.
        normalised = np.divide(
            cm,
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )
        labels = list(self.classes)
        df_cm = pd.DataFrame(normalised, index=labels, columns=labels)
        plt.figure(figsize=(7, 5))
        sns.heatmap(df_cm, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.show()

    def collect_misclassified_images(self, num_images):
        """Scan the test loader until ``num_images`` misclassifications are found.

        Inference runs in eval mode without gradient tracking, and the
        previous train/eval mode is restored afterwards.

        Returns:
            A 4-tuple ``(images, true_labels, predicted_labels, found)``.
            The three lists hold CPU tensors truncated to ``num_images``;
            ``found`` is the number collected before truncation, which may
            be smaller or larger than ``num_images``.
        """
        images = []
        true_labels = []
        predicted_labels = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for x, y in self.test_dataloader():
                    # Loader batches are on the CPU; the model may not be.
                    x = x.to(self.device)
                    y = y.to(self.device)
                    pred = self(x).argmax(dim=1, keepdim=True)
                    # squeeze(1) keeps the batch axis even for a batch of one.
                    wrong = pred.ne(y.view_as(pred)).squeeze(1)
                    images.extend(x[wrong].cpu())
                    true_labels.extend(y[wrong].cpu())
                    predicted_labels.extend(pred[wrong].cpu())
                    if len(images) >= num_images:
                        break
        finally:
            self.train(was_training)
        return (
            images[:num_images],
            true_labels[:num_images],
            predicted_labels[:num_images],
            len(images),
        )

    def normalize_image(self, img_tensor):
        """Min-max scale ``img_tensor`` to [0, 1]; a constant input maps to zeros."""
        min_val = img_tensor.min()
        value_range = img_tensor.max() - min_val
        if value_range == 0:
            # Avoid 0/0 on a constant image.
            value_range = 1
        return (img_tensor - min_val) / value_range

    def _gradcam_overlays(self, images, true_labels, target_layer, transparency):
        """Return one Grad-CAM overlay per image, targeting its true label.

        ``target_layer == -2`` selects ``conv2``; any other value selects
        ``conv3``. Moves this module to the CPU as a side effect.
        """
        layer = self.conv2 if target_layer == -2 else self.conv3
        # The model is moved to the CPU first. use_cuda is not passed:
        # NOTE(review): newer pytorch-grad-cam releases are understood to
        # have dropped that argument and older ones to default it to
        # False -- confirm against the installed version.
        grad_cam = GradCAM(model=self.cpu(), target_layers=[layer])
        overlays = []
        for image, label in zip(images, true_labels):
            image = image.cpu()
            # (C, H, W) tensor -> (H, W, C) array, then undo normalisation.
            rgb = image.numpy().transpose(1, 2, 0)
            rgb = np.clip(self._DATASET_STD * rgb + self._DATASET_MEAN, 0, 1)
            heatmap = grad_cam(
                input_tensor=image.unsqueeze(0),
                targets=[ClassifierOutputTarget(label)],
            )[0, :]
            overlays.append(
                show_cam_on_image(rgb, heatmap, use_rgb=True, image_weight=transparency)
            )
        return overlays

    def get_missed_gradcam_images(self, target_layer=-1, transparency=0.5, num_images=10):
        """Return Grad-CAM overlays for up to ``num_images`` misclassified images.

        Fewer overlays are returned when the test set yields fewer
        misclassifications than requested.
        """
        images, true_labels, _, _ = self.collect_misclassified_images(num_images)
        return self._gradcam_overlays(images, true_labels, target_layer, transparency)

    def create_layout(self, num_images, use_gradcam):
        """Create a figure and a GridSpec with one row per image.

        Three columns are used with Grad-CAM and two without; the first
        column is narrower than the image columns.
        """
        if use_gradcam:
            width_ratios = [0.3, 1, 1]
        else:
            width_ratios = [0.5, 1]
        fig = plt.figure(figsize=(12, 5 * num_images))
        gs = gridspec.GridSpec(
            num_images, len(width_ratios), figure=fig, width_ratios=width_ratios
        )
        return fig, gs

    def plot_missed(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
        """Draw row ``i``: the image in column 1 and, optionally, Grad-CAM in column 2."""
        panels = [(1, img, label_text)]
        if use_gradcam:
            panels.append((2, gradcam_img, "GradCAM Image"))
        for column, picture, title in panels:
            ax = fig.add_subplot(gs[i, column])
            ax.imshow(picture)
            ax.set_title(title, fontsize=12)
            ax.axis("off")
            for spine in ax.spines.values():
                spine.set_visible(False)

    def show_misclassified_images(self, num_images=10, use_gradcam=False, gradcam_layer=-1, transparency=0.5):
        """Plot up to ``num_images`` misclassified test images and return the figure.

        The test set is scanned once, and the same images feed both the
        plain panels and the Grad-CAM overlays. An empty figure is
        returned when nothing was misclassified.
        """
        images, true_labels, predicted_labels, _ = self.collect_misclassified_images(num_images)
        shown = len(images)
        if shown == 0:
            return plt.figure()

        fig, gs = self.create_layout(shown, use_gradcam)
        if use_gradcam:
            overlays = self._gradcam_overlays(images, true_labels, gradcam_layer, transparency)
        else:
            overlays = [None] * shown

        rows = zip(images, true_labels, predicted_labels, overlays)
        for i, (image, true_label, predicted_label, overlay) in enumerate(rows):
            # (C, H, W) tensor -> (H, W, C) array scaled to [0, 1] for imshow.
            img = self.normalize_image(image.numpy().transpose((1, 2, 0)))
            label_text = f"True Label: {self.classes[int(true_label)]}\nPredicted Label: {self.classes[int(predicted_label)]}"
            self.plot_missed(fig, gs, i, img, label_text, use_gradcam, overlay)
        plt.tight_layout()
        return fig