Spaces:
Sleeping
Sleeping
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from torchvision import transforms | |
| from torch.nn.functional import interpolate | |
class GradCAM:
    """Compute Grad-CAM heatmaps for one layer of a PyTorch model.

    A forward hook on ``target_layer`` records the layer's output, and a
    tensor hook on that output records the gradient flowing back into it.
    ``generate`` combines the two into a class-activation map.

    Attributes are kept as lists (``activations`` / ``gradients``) and are
    reset at the start of every ``generate`` call, so they only ever hold
    the captures of the most recent pass.
    """

    def __init__(self, model, target_layer):
        """Attach the capture hook to ``target_layer``.

        Args:
            model: Callable model; invoked as ``model(input_tensor)``.
            target_layer: Module whose output and output-gradient are
                captured. Its forward output is expected to be a single
                tensor (``.detach()`` is called on it).
        """
        self.model = model
        self.target_layer = target_layer
        self.activations = []
        self.gradients = []
        # Keep the handle so the hook can be detached later; otherwise every
        # GradCAM instance leaves a permanent hook on the layer.
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activations)
        ]

    def save_activations(self, module, input, output):
        """Forward hook: store the layer output and arm gradient capture."""
        self.activations.append(output.detach())

        # Gradients are captured with a tensor hook on the output instead of
        # the deprecated Module.register_backward_hook. The hook must return
        # None so the gradient itself is left unmodified.
        if output.requires_grad:

            def _capture(grad):
                self.save_gradients(module, None, (grad,))

            output.register_hook(_capture)

    def save_gradients(self, module, grad_input, grad_output):
        """Store the gradient w.r.t. the layer output (``grad_output[0]``)."""
        self.gradients.append(grad_output[0].detach())

    def remove_hooks(self):
        """Detach the hooks this instance registered on the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, input_tensor):
        """Run the wrapped model on ``input_tensor`` and return its output."""
        return self.model(input_tensor)

    def generate(self, input_tensor, target_class, output_size=(224, 224)):
        """Return the Grad-CAM heatmap for ``target_class``.

        Args:
            input_tensor: Batched model input. Only the first sample of the
                batch is turned into a heatmap.
            target_class: Column index into the model output whose score is
                back-propagated.
            output_size: Spatial size the map is bilinearly resized to.
                Defaults to ``(224, 224)``.

        Returns:
            A float32 NumPy array of shape ``output_size`` with negative
            evidence clipped to zero and, unless the map is all zeros,
            scaled so its maximum is 1.

        Raises:
            RuntimeError: If the target layer produced no activation or no
                gradient during this pass.
        """
        # Discard captures from earlier passes; without this the lists keep
        # growing and index 0 would always refer to the very first input.
        self.activations.clear()
        self.gradients.clear()

        output = self.forward(input_tensor)

        self.model.zero_grad()
        loss = output[:, target_class].mean()
        # The graph is not reused after this call, so it need not be retained.
        loss.backward()

        if not self.activations or not self.gradients:
            raise RuntimeError(
                "target_layer captured no activations/gradients; make sure it "
                "is executed in the model's forward pass and that its output "
                "requires grad"
            )

        # First sample of the batch, as (channels, height, width) arrays.
        # NOTE(review): assumes the layer runs once per forward pass and
        # emits a 4-D (batch, channels, H, W) tensor - confirm for new models.
        activations = self.activations[0][0].float().cpu().numpy()
        gradients = self.gradients[0][0].float().cpu().numpy()

        # Channel weights are the spatially averaged gradients.
        weights = gradients.mean(axis=(1, 2))

        # Weighted sum over channels -> (height, width).
        cam = np.tensordot(weights, activations, axes=1)
        cam = np.ascontiguousarray(np.maximum(cam, 0), dtype=np.float32)

        resized = interpolate(
            torch.from_numpy(cam[None, None]),
            size=output_size,
            mode='bilinear',
            align_corners=False,
        )
        cam = resized[0, 0].numpy()

        peak = cam.max()
        if peak != 0:
            cam = cam / peak
        return cam
def generate_gradcam(image, target_class, model, target_layer):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        image: A ``torch.Tensor`` or anything ``transforms.ToTensor``
            accepts. A tensor that is not already 4-D gets a leading batch
            dimension added.
        target_class: Index of the model output to explain.
        model: Model to run; the input is moved to the device of its first
            parameter (CPU if it has none).
        target_layer: Layer of ``model`` handed to ``GradCAM``.

    Returns:
        Whatever ``GradCAM.generate`` returns for the prepared input.
    """
    if not isinstance(image, torch.Tensor):
        image = transforms.ToTensor()(image)

    # Use the model's own device rather than relying on a module-level
    # ``device`` name; the input has to live where the weights are anyway.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    batch = image if image.dim() == 4 else image.unsqueeze(0)
    # Request input gradients so back-propagation reaches the target layer
    # even when the model's parameters are frozen.
    image_preprocessed = batch.requires_grad_(True).to(device)

    gradcam = GradCAM(model, target_layer)
    try:
        return gradcam.generate(image_preprocessed, target_class)
    finally:
        # Release the layer hooks when the GradCAM object offers a cleanup
        # method, so repeated calls do not pile hooks up on the model.
        cleanup = getattr(gradcam, "remove_hooks", None)
        if callable(cleanup):
            cleanup()