# NOTE(review): removed non-Python export residue ("Spaces:" / "Build error"
# lines) left by a paste or build-log capture; it was not part of the module.
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 12 13:50:39 2023
@author: prarthana.ts
"""
# Standard library
import math
import os
from io import BytesIO

# Third-party
import cv2
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from albumentations import *
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, random_split
from torch_lr_finder import LRFinder
from torchmetrics import Accuracy
from torchsummary import summary
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Local
from helper import *
from model import *
class CustomResNet(pl.LightningModule):
    """ResNet9-style network for CIFAR-10 as a self-contained LightningModule.

    Owns the model, the optimizer/OneCycleLR schedule, the CIFAR-10 data
    pipeline, and post-hoc analysis helpers (confusion matrix, misclassified
    images, GradCAM overlays).

    Args:
        config: dict with at least 'classes' (label names), 'num_classes',
            'data_dir', and 'batch_size'.
        dropout: dropout probability forwarded to ``PrepBlock``.
        train_transforms / test_transforms: albumentations pipelines applied
            through ``CifarAlbumentations``.
    """

    # CIFAR-10 per-channel normalisation statistics, used only to
    # un-normalise images for display in the GradCAM helper.
    DATASET_MEAN = np.array([0.49139968, 0.48215841, 0.44653091])
    DATASET_STD = np.array([0.24703223, 0.24348513, 0.26158784])

    def __init__(self, config, dropout, train_transforms, test_transforms):
        super().__init__()
        self.config = config
        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.classes = config['classes']
        # Backbone: prep -> conv(64->128)+residual -> conv(128->256)
        #           -> conv(256->512)+residual -> maxpool -> linear head.
        self.prep = PrepBlock(dropout)
        self.conv1 = ConvolutionBlock(64, 128)
        self.R1 = ResidualBlock(128)
        self.conv2 = ConvolutionBlock(128, 256)
        self.conv3 = ConvolutionBlock(256, 512)
        self.R2 = ResidualBlock(512)
        self.maxpool = nn.MaxPool2d(kernel_size=(4, 4))
        self.linear = nn.Linear(512, 10)
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=config['num_classes']
        )
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            task="multiclass", num_classes=config['num_classes']
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10)."""
        x = self.prep(x)
        x = self.conv1(x)
        x = self.R1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.R2(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 512)
        x = self.linear(x)
        # BUG FIX: the original had an unreachable `return x` after this
        # return, and a redundant `x.view(-1, 10)` (linear already emits
        # (N, 10)); both removed.
        return F.log_softmax(x, dim=1)

    def _evaluate(self, batch):
        """Shared forward pass: return (loss, accuracy, argmax preds, outputs).

        NOTE: forward() already returns log-probabilities; cross_entropy
        applies log_softmax again, which is a no-op on log-probs (log_softmax
        is idempotent), so the loss value is unchanged from the original.
        """
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).float().mean()
        return loss, acc, pred, y_hat

    def training_step(self, batch, batch_idx):
        """One optimisation step; logs per-step and per-epoch loss/accuracy."""
        loss, acc, _, _ = self._evaluate(batch)
        self.log('train_losses', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation loss/accuracy for one batch."""
        loss, acc, _, _ = self._evaluate(batch)
        self.log('validation_losses', loss, prog_bar=True)
        self.log('validation_accuracy', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Test loss/accuracy for one batch; accumulates the confusion matrix."""
        _, y = batch
        loss, acc, pred, y_hat = self._evaluate(batch)
        self.confusion_matrix.update(y_hat, y)
        self.log('test_losses', loss, prog_bar=True)
        self.log('test_accuracy', acc, prog_bar=True)
        return pred

    def configure_optimizers(self):
        """Adam with a per-step OneCycleLR schedule.

        BUG FIX: the optimizer was built from a *global* ``model`` instead of
        this instance's parameters, so training would silently optimise the
        wrong (or no) module. Also removed an unused ``criterion`` local.
        """
        optimizer = Adam(self.parameters(), lr=0.01)
        # BUG FIX: the CIFAR-10 *train* split has 50,000 images (the original
        # used 60,000, the full train+test size, stretching the schedule
        # beyond the real epoch length).
        steps_per_epoch = 50000 // self.config['batch_size']
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                max_lr=0.0265608,  # presumably from the LR finder — TODO confirm
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
                pct_start=5 / self.trainer.max_epochs,  # ~5 warm-up epochs
            ),
            "interval": "step",  # step the LR every batch, not every epoch
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

    def prepare_data(self):
        """Download CIFAR-10 once (called by Lightning on a single process)."""
        CIFAR10(self.config['data_dir'], train=True, download=True)
        CIFAR10(self.config['data_dir'], train=False, download=True)

    def setup(self, stage=None):
        """Build the albumentations-wrapped datasets for the given stage."""
        if stage == 'fit' or stage is None:
            self.train_dataset = CifarAlbumentations(
                CIFAR10(self.config['data_dir'], train=True, download=True),
                transforms=self.train_transforms)
            self.val_dataset = CifarAlbumentations(
                CIFAR10(self.config['data_dir'], train=False, download=True),
                transforms=self.test_transforms)
        # BUG FIX: original condition was `stage == 'test' or stage`, which is
        # truthy for *any* non-empty stage (including 'fit') and falsy for
        # stage=None; the intended guard is the usual `... or stage is None`.
        if stage == 'test' or stage is None:
            self.test_dataset = CifarAlbumentations(
                CIFAR10(self.config['data_dir'], train=False, download=True),
                transforms=self.test_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
                          num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.config['batch_size'],
                          num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
                          num_workers=os.cpu_count())

    def on_test_end(self) -> None:
        """Plot the row-normalised confusion matrix accumulated in test_step.

        BUG FIX: the original read a *global* ``config`` for the class names
        instead of ``self.config``; it also contained a dead ``if True/else``
        — only the row-normalised branch could ever run, so only it is kept.
        """
        cm = self.confusion_matrix.cpu().compute().numpy()
        class_names = list(self.config['classes'])
        df_cm = pd.DataFrame(
            cm / np.sum(cm, axis=1)[:, None],  # normalise each true-label row
            index=class_names,
            columns=class_names,
        )
        plt.figure(figsize=(7, 5))
        sns.heatmap(df_cm, annot=True, cmap="Blues", fmt=".3f", linewidths=0.5)
        plt.tight_layout()
        plt.ylabel("True label")
        plt.xlabel("Predicted label")
        plt.show()

    def collect_misclassified_images(self, num_images):
        """Scan the test loader until ``num_images`` misclassified samples are found.

        Returns:
            (images, true_labels, predicted_labels, total_collected) — the
            first three truncated to ``num_images``; the last is the raw count
            gathered before stopping (matches the original's return).
        """
        images, true_labels, predicted_labels = [], [], []
        num_collected = 0
        with torch.no_grad():  # inference only — don't build autograd graphs
            for x, y in self.test_dataloader():
                y_hat = self.forward(x)
                pred = y_hat.argmax(dim=1, keepdim=True)
                # BUG FIX (naming): the original called the *correct*-prediction
                # mask `misclassified_mask` and then negated it everywhere.
                wrong = ~pred.eq(y.view_as(pred)).squeeze()
                images.extend(x[wrong].detach())
                true_labels.extend(y[wrong].detach())
                predicted_labels.extend(pred[wrong].detach())
                num_collected += int(wrong.sum())
                if num_collected >= num_images:
                    break
        return (images[:num_images], true_labels[:num_images],
                predicted_labels[:num_images], len(images))

    def normalize_image(self, img_tensor):
        """Min-max scale an image (tensor or ndarray) into [0, 1].

        Robustness: a constant image now maps to all zeros instead of
        dividing by zero (the original produced NaNs).
        """
        min_val = img_tensor.min()
        max_val = img_tensor.max()
        value_range = max_val - min_val
        if value_range == 0:
            return img_tensor - min_val
        return (img_tensor - min_val) / value_range

    def get_missed_gradcam_images(self, target_layer=-1, transparency=0.5, num_images=10):
        """Build GradCAM overlays for the first ``num_images`` misclassified images.

        Args:
            target_layer: -2 targets ``self.conv2``; any other value targets
                ``self.conv3`` (matches the original dispatch).
            transparency: blend weight of the original image in the overlay.
        """
        misclassified_images, true_labels, _, _ = \
            self.collect_misclassified_images(num_images)
        layers = [self.conv2] if target_layer == -2 else [self.conv3]
        # NOTE(review): `use_cuda` is deprecated/removed in recent
        # pytorch-grad-cam releases — confirm the pinned version accepts it.
        grad_cam = GradCAM(model=self.cpu(), target_layers=layers, use_cuda=False)
        gradcam_images = []
        for i in range(num_images):
            # Un-normalise for display: (C,H,W) tensor -> (H,W,C) float image in [0,1].
            img = misclassified_images[i].cpu().numpy().transpose(1, 2, 0)
            img = np.clip(self.DATASET_STD * img + self.DATASET_MEAN, 0, 1)
            targets = [ClassifierOutputTarget(true_labels[i])]
            grayscale_cam = grad_cam(
                input_tensor=misclassified_images[i].unsqueeze(0).cpu(),
                targets=targets,
            )[0, :]
            gradcam_images.append(
                show_cam_on_image(img, grayscale_cam, use_rgb=True,
                                  image_weight=transparency))
        return gradcam_images

    def create_layout(self, num_images, use_gradcam):
        """Figure + GridSpec: label column, image column, optional GradCAM column."""
        num_cols = 3 if use_gradcam else 2
        fig = plt.figure(figsize=(12, 5 * num_images))
        width_ratios = [0.3, 1, 1] if use_gradcam else [0.5, 1]
        gs = gridspec.GridSpec(num_images, num_cols, figure=fig,
                               width_ratios=width_ratios)
        return fig, gs

    @staticmethod
    def _show_on(ax, img, title):
        """Render `img` on `ax` with a title and all axes chrome hidden."""
        ax.imshow(img)
        ax.set_title(title, fontsize=12)
        ax.axis("off")
        for side in ('top', 'bottom', 'left', 'right'):
            ax.spines[side].set_visible(False)

    def plot_missed(self, fig, gs, i, img, label_text, use_gradcam=False, gradcam_img=None):
        """Draw one misclassified image (and optional GradCAM overlay) on row ``i``."""
        self._show_on(fig.add_subplot(gs[i, 1]), img, label_text)
        if use_gradcam:
            self._show_on(fig.add_subplot(gs[i, 2]), gradcam_img, "GradCAM Image")

    def show_misclassified_images(self, num_images=10, use_gradcam=False,
                                  gradcam_layer=-1, transparency=0.5):
        """Return a figure showing the first ``num_images`` misclassified test
        images with their true/predicted labels, optionally with GradCAM."""
        misclassified_images, true_labels, predicted_labels, _ = \
            self.collect_misclassified_images(num_images)
        fig, gs = self.create_layout(num_images, use_gradcam)
        grad_cam_images = None
        if use_gradcam:
            grad_cam_images = self.get_missed_gradcam_images(
                target_layer=gradcam_layer, transparency=transparency,
                num_images=num_images)
        for i in range(num_images):
            # (C,H,W) tensor -> (H,W,C) numpy image, min-max scaled for display.
            img = misclassified_images[i].numpy().transpose((1, 2, 0))
            img = self.normalize_image(img)
            label_text = (f"True Label: {self.classes[true_labels[i]]}\n"
                          f"Predicted Label: {self.classes[predicted_labels[i]]}")
            self.plot_missed(fig, gs, i, img, label_text, use_gradcam,
                             grad_cam_images[i] if use_gradcam else None)
        plt.tight_layout()
        return fig