"""Grad-CAM overlay generators for a ResNet classifier and an EfficientNet/ConvNeXt fusion model."""

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def get_resnet_gradcam(image_path, predictor, output_path):
    """Write a Grad-CAM heatmap overlay for a ResNet-style classifier.

    Args:
        image_path: Path to the input image (opened and converted to RGB).
        predictor: Project object exposing ``model``, ``device`` and
            ``test_transforms`` (assumed to yield a CHW float tensor —
            TODO confirm against the predictor implementation).
        output_path: Destination path for the BGR overlay written by OpenCV.

    Returns:
        True on success.
    """
    model = predictor.model
    device = predictor.device
    model.eval()

    features, gradients = [], []

    def forward_hook(module, input, output):
        features.append(output)

    def backward_hook(module, grad_in, grad_out):
        # grad_out[0] is the gradient w.r.t. the layer's output activations.
        gradients.append(grad_out[0])

    # Hook the last block of layer4 — the deepest convolutional feature map.
    target_layer = model.model.layer4[-1]
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_full_backward_hook(backward_hook)

    try:
        original_img = Image.open(image_path).convert("RGB")
        input_tensor = predictor.test_transforms(original_img).unsqueeze(0).to(device)

        model.zero_grad()
        output = model(input_tensor)
        pred_class_idx = output.argmax(dim=1).item()
        score = output[0, pred_class_idx]
        # Backprop the top-class score to populate the hooked gradients.
        score.backward()
    finally:
        # Remove hooks even if loading/inference fails, so they never leak
        # onto the shared model object.
        handle_fw.remove()
        handle_bw.remove()

    acts = features[0].cpu().data.numpy()[0]    # (C, H, W) activations
    grads = gradients[0].cpu().data.numpy()[0]  # (C, H, W) gradients

    # Grad-CAM: per-channel weight = spatial mean of the gradient.
    weights = np.mean(grads, axis=(1, 2))
    cam = np.zeros(acts.shape[1:], dtype=np.float32)
    for i, w in enumerate(weights):
        cam += w * acts[i]

    cam = np.maximum(cam, 0)  # ReLU: keep only positively contributing regions
    cam = cv2.resize(cam, (original_img.width, original_img.height))
    # Min-max normalize; epsilon guards against a flat (all-zero) map.
    cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    original_np = np.array(original_img)
    overlay = cv2.addWeighted(
        cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR), 0.6, heatmap, 0.4, 0
    )
    cv2.imwrite(output_path, overlay)
    return True


def get_fusion_gradcam(image_path, predictor, output_path):
    """Write a Grad-CAM heatmap overlay for the dual-backbone fusion model.

    The CAM is taken from the last stage of the EfficientNet branch
    (``model.eff_features[-1]``); gradients are captured via
    ``retain_grad`` on the hooked activation rather than a backward hook.

    Args:
        image_path: Path to the input image (opened and converted to RGB).
        predictor: Project object exposing ``model``, ``device``,
            ``eff_normalize`` and ``convnext_processor``.
        output_path: Destination path for the BGR overlay written by OpenCV.

    Returns:
        True on success.

    Raises:
        RuntimeError: If no gradient reached the hooked activation.
    """
    model = predictor.model
    device = predictor.device
    model.eval()

    target_layer = model.eff_features[-1]
    activation = None

    def forward_hook(module, inp, out):
        nonlocal activation
        activation = out
        # Non-leaf tensors drop their .grad by default; keep it for CAM.
        activation.retain_grad()

    handle = target_layer.register_forward_hook(forward_hook)

    try:
        original_img = Image.open(image_path).convert("RGB")
        pixel_eff = predictor.eff_normalize(original_img).unsqueeze(0).to(device)
        inputs_cnx = predictor.convnext_processor(images=original_img, return_tensors="pt")
        pixel_cnx = inputs_cnx["pixel_values"].to(device)

        # Match input dtype to a half-precision model.
        if next(model.parameters()).dtype == torch.float16:
            pixel_eff = pixel_eff.half()
            pixel_cnx = pixel_cnx.half()

        model.zero_grad()
        output = model(pixel_eff, pixel_cnx)
        pred_class_idx = output.argmax(dim=1).item()
        score = output[0, pred_class_idx]
        score.backward()
    finally:
        # Remove the hook even if preprocessing or backward fails.
        handle.remove()

    if activation is None or activation.grad is None:
        raise RuntimeError(
            "Gradients could not be extracted. Ensure requires_grad=True is properly set."
        )

    acts = activation[0].detach().float()      # (C, H, W) activations
    grads = activation.grad[0].detach().float()

    # Channel weights = spatial mean of gradients; weighted sum over channels.
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.sum(weights * acts, dim=0)
    cam = F.relu(cam)

    cam = cam.cpu().numpy()
    if cam.max() > cam.min():
        cam = (cam - cam.min()) / (cam.max() - cam.min())
    else:
        # Flat map: emit an all-zero CAM rather than dividing by zero.
        cam = np.zeros_like(cam)

    cam = np.uint8(255 * cam)
    cam_resized = cv2.resize(
        cam, (original_img.width, original_img.height), interpolation=cv2.INTER_LINEAR
    )
    heatmap = cv2.applyColorMap(cam_resized, cv2.COLORMAP_JET)

    original_np = np.array(original_img)
    original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
    # NOTE(review): weights sum to 1.1 (vs 0.6/0.4 in the ResNet variant) —
    # bright areas can saturate; confirm this blend is intentional.
    overlay = cv2.addWeighted(original_bgr, 0.5, heatmap, 0.6, 0)
    cv2.imwrite(output_path, overlay)
    return True