import torch
import numpy as np


class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) for one layer.

    Captures ``target_layer``'s forward output and the gradient of a chosen
    class score with respect to that output. It combines them into a coarse
    heat-map at the layer's spatial resolution.

    Call :meth:`remove` when finished so the hook stops firing on later
    forward passes of ``model``.
    """

    def __init__(self, model, target_layer):
        """Attach the capture hook to ``target_layer``.

        Args:
            model: Module whose forward returns a mapping with a ``"class"``
                entry holding per-class scores, indexed as ``[:, class_idx]``.
            target_layer: Sub-module of ``model`` whose output is explained.
                Its output is assumed to be a single 4-D tensor
                (batch, channels, height, width) -- TODO confirm per model.
        """
        self.model = model
        self.gradients = None
        self.activations = None
        # A forward hook alone is enough: it puts a tensor hook on the layer
        # output to capture the gradient. This replaces the deprecated
        # Module.register_backward_hook. Unlike register_full_backward_hook,
        # it tolerates in-place ops (e.g. ReLU(inplace=True)) applied to the
        # layer output further down the network.
        self._handles = [target_layer.register_forward_hook(self._fwd)]

    def _fwd(self, module, inputs, output):
        """Forward hook: store the layer output and hook its gradient."""
        self.activations = output.detach()
        # Outside a grad-enabled context the output has no graph, and
        # register_hook would raise, so only hook when backward is possible.
        if output.requires_grad:
            output.register_hook(self._bwd)

    def _bwd(self, grad):
        """Tensor hook: store d(score)/d(layer output)."""
        self.gradients = grad

    def remove(self):
        """Detach the hook from ``target_layer``; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate(self, img_tensor, class_idx):
        """Compute the Grad-CAM heat-map of ``class_idx`` for one input.

        Args:
            img_tensor: A single un-batched input; a batch dimension is added
                with ``unsqueeze(0)`` (presumably shape (C, H, W)).
            class_idx: Index of the class score to explain.

        Returns:
            A numpy array of shape (h, w) with values in [0, 1]. The shape is
            the target layer's spatial size, not the input's; no upsampling
            is done.

        Raises:
            RuntimeError: If the target layer's output or gradient was not
                captured. This happens, for example, when the layer is not
                used in the forward pass or its output does not require grad.
        """
        # Never mix tensors left over from a previous call with this one.
        self.gradients = None
        self.activations = None
        # backward() needs a graph even if the caller is inside no_grad().
        with torch.enable_grad():
            out = self.model(img_tensor.unsqueeze(0))
            score = out["class"][:, class_idx].sum()
            self.model.zero_grad()
            score.backward()
        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM: no gradient reached the target layer; check that it "
                "runs in the model's forward pass and its output requires grad"
            )
        # Channel weights: gradients averaged over the two spatial dims.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.activations).sum(dim=1))
        # Scale to [0, 1]; the epsilon guards against an all-zero map.
        cam = cam / (cam.max() + 1e-8)
        return cam.detach().cpu().numpy()[0]