| import torch, torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
| import numpy as np, io, base64 |
|
|
| def _normalize_cam(cam): |
| cam = cam - cam.min() |
| cam = cam / (cam.max() + 1e-6) |
| return cam |
|
|
def grad_cam(model, img: Image.Image, img_size=224, target_layer=None, device="cpu"):
    """Compute a Grad-CAM heat map for the model's top-1 prediction on ``img``.

    The image is resized and center-cropped to ``img_size``, passed through
    ``model``, and the gradient of the highest logit with respect to the
    output of ``target_layer`` is used to weight that layer's activation
    channels into a single spatial map.

    Side effects: ``model`` is switched to eval mode, its parameter gradients
    are reset before the backward pass, and the backward pass then runs on
    the model. The hooks installed on ``target_layer`` are always removed
    before this function returns or raises.

    Args:
        model: Classifier module called as ``model(x)`` on a batch of one.
        img: PIL image to explain.
        img_size: Side length in pixels of the square model input and of the
            returned maps.
        target_layer: Module whose output is visualised. Defaults to
            ``model.features[-1][0]``.
        device: Device the input tensor is moved to (the model itself is not
            moved by this function).

    Returns:
        A dict with the keys

        * ``"pred"``: index of the highest logit, as an ``int``.
        * ``"probs"``: softmax of the logits for the single input, as a
          numpy array.
        * ``"overlay"``: ``0.6 * input_image + 0.4 * jet(cam)``, clipped to
          ``[0, 1]``.
        * ``"input_image"``: the model input, channels-last and min-max
          rescaled, as a numpy array.
        * ``"cam"``: the normalized activation map, a numpy array of shape
          ``(img_size, img_size)``.
    """
    model.eval()
    tfms = transforms.Compose([
        transforms.Resize(int(img_size * 1.15)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
    x = tfms(img).unsqueeze(0).to(device)
    # Track gradients on the input itself (presumably so the backward hook
    # still fires when the model's parameters are frozen).
    x.requires_grad_(True)

    if target_layer is None:
        # NOTE(review): assumes the model exposes an indexable ``features``
        # container whose last element is itself indexable - confirm.
        target_layer = model.features[-1][0]

    activations, grads = [], []

    def fwd_hook(_module, _inputs, out):
        activations.append(out)

    def bwd_hook(_module, _grad_in, grad_out):
        grads.append(grad_out[0])

    # Register inside try/finally so a failure in the forward or backward
    # pass cannot leave stale hooks attached to the caller's model.
    handles = []
    try:
        handles.append(target_layer.register_forward_hook(fwd_hook))
        handles.append(target_layer.register_full_backward_hook(bwd_hook))

        logits = model(x)
        pred = int(logits.argmax(dim=1).item())
        score = logits[0, pred]
        model.zero_grad(set_to_none=True)
        score.backward()
    finally:
        for handle in handles:
            handle.remove()

    # Reduce the captured activation to a 3-D tensor: drop the batch axis of
    # a 4-D output, keep a 3-D output as is, otherwise average over axis 0.
    acts = activations[-1]
    if acts.dim() == 4:
        acts = acts[0]
    elif acts.dim() != 3:
        acts = acts.mean(dim=0)

    layer_grads = grads[-1]
    if layer_grads.dim() == 4:
        layer_grads = layer_grads[0]

    if layer_grads.shape[0] == acts.shape[0]:
        # One weight per leading-axis slice: the spatial mean of its gradient.
        weights = layer_grads.mean(dim=(1, 2))
        cam = (weights[:, None, None] * acts).sum(0)
    else:
        # Leading sizes disagree: fall back to the unweighted mean activation.
        cam = acts.mean(dim=0)

    # Keep only positive evidence, then upsample to the input resolution.
    cam = F.relu(cam)[None, None, ...]
    cam = F.interpolate(
        cam, size=(img_size, img_size), mode='bilinear', align_corners=False
    )[0, 0]
    cam = _normalize_cam(cam).detach().cpu().numpy()

    # Channels-last copy of the model input, min-max rescaled for display.
    img_np = x[0].detach().cpu().permute(1, 2, 0).numpy()
    img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-6)

    import matplotlib.cm as cm
    heat = cm.jet(cam)[..., :3]  # drop the alpha channel of the colormap
    overlay = np.clip(0.6 * img_np + 0.4 * heat, 0, 1)

    probs = torch.softmax(logits, dim=1)[0].detach().cpu().numpy()

    return {"pred": pred, "probs": probs, "overlay": overlay, "input_image": img_np, "cam": cam}
|
|