Spaces:
Runtime error
Runtime error
| import torch | |
| from pytorch_grad_cam import GradCAM | |
| from torch import Tensor | |
| from transformers import ViTForImageClassification | |
| def grad_cam(images: Tensor, vit: ViTForImageClassification, use_cuda: bool = False) -> Tensor: | |
| """Performs the Grad-CAM method on a batch of images (https://arxiv.org/pdf/1610.02391.pdf).""" | |
| # Wrap the ViT model to be compatible with GradCAM | |
| vit = ViTWrapper(vit) | |
| vit.eval() | |
| # Create GradCAM object | |
| cam = GradCAM( | |
| model=vit, | |
| target_layers=[vit.target_layer], | |
| reshape_transform=_reshape_transform, | |
| use_cuda=use_cuda, | |
| ) | |
| # Compute GradCAM masks | |
| grayscale_cam = cam( | |
| input_tensor=images, | |
| targets=None, | |
| eigen_smooth=True, | |
| aug_smooth=True, | |
| ) | |
| return torch.from_numpy(grayscale_cam) | |
| def _reshape_transform(tensor, height=14, width=14): | |
| result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2)) | |
| # Bring the channels to the first dimension | |
| result = result.transpose(2, 3).transpose(1, 2) | |
| return result | |
| class ViTWrapper(torch.nn.Module): | |
| """ViT Wrapper to use with Grad-CAM.""" | |
| def __init__(self, vit: ViTForImageClassification): | |
| super().__init__() | |
| self.vit = vit | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.vit(x).logits | |
| def target_layer(self): | |
| return self.vit.vit.encoder.layer[-2].layernorm_after | |