# glaunet-screening / visualize.py
# GradCAM visualizations (commit 68b4742)
import cv2
import numpy as np
import torch
from pipeline import preprocess_disc_crop, preprocess_for_efficientnet, TRAIN_MEAN, TRAIN_STD
def make_visualization_images(image_bgr: np.ndarray, result: dict, eff_model, device) -> dict:
    """Build the dictionary of RGB visualization images for one screening result.

    NOTE(review): assumes ``result`` carries "yolo_bbox", "crop", "pred_map",
    "od_mask" and "oc_mask" as produced by the pipeline, and that the masks
    are 512x512 booleans — confirm against the upstream inference code.
    """
    rgb_full = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Full image with the YOLO disc bounding box drawn on top (gold, 3 px).
    bbox_img = rgb_full.copy()
    bx1, by1, bx2, by2 = (int(c) for c in result["yolo_bbox"])
    cv2.rectangle(bbox_img, (bx1, by1), (bx2, by2), (255, 215, 0), 3)

    # Disc crop, both raw and as the un-normalized segmentation-network input.
    disc_bgr = result["crop"]
    disc_rgb = cv2.cvtColor(disc_bgr, cv2.COLOR_BGR2RGB)
    proc_tensor = preprocess_disc_crop(disc_bgr, img_size=512)
    proc_img = proc_tensor.permute(1, 2, 0).numpy() * TRAIN_STD + TRAIN_MEAN
    proc_img = np.clip(proc_img * 255, 0, 255).astype(np.uint8)

    # Segmentation classes: 1 = optic disc (green), 2 = optic cup (red).
    pred_map = result["pred_map"]
    seg_img = np.zeros((*pred_map.shape, 3), dtype=np.uint8)
    seg_img[pred_map == 1] = [0, 200, 0]
    seg_img[pred_map == 2] = [200, 0, 0]

    # 50/50 blend of the resized crop with the disc/cup masks.
    blended = cv2.resize(disc_rgb, (512, 512), interpolation=cv2.INTER_LINEAR).astype(np.float32)
    blended[result["od_mask"]] = blended[result["od_mask"]] * 0.5 + np.array([0, 200, 0]) * 0.5
    blended[result["oc_mask"]] = blended[result["oc_mask"]] * 0.5 + np.array([200, 0, 0]) * 0.5
    blended = blended.astype(np.uint8)

    # EfficientNet input (ImageNet normalization) and its Grad-CAM overlays,
    # one on the preprocessed view and one on the plain resized original.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    eff_tensor = preprocess_for_efficientnet(image_bgr, img_size=300)
    eff_display = eff_tensor.permute(1, 2, 0).numpy() * imagenet_std + imagenet_mean
    eff_display = np.clip(eff_display, 0, 1)
    cam, _ = compute_gradcam(eff_model, eff_tensor, device)
    cam_on_proc = overlay_gradcam(eff_display, cam, alpha=0.4)
    cam_on_orig = overlay_gradcam(cv2.resize(rgb_full, (300, 300)) / 255.0, cam, alpha=0.4)

    return {
        "img_yolo": bbox_img,
        "crop_rgb": disc_rgb,
        "img_proc": proc_img,
        "seg_rgb": seg_img,
        "overlay_rgb": blended,
        "img_eff_input": rgb_full,
        "img_eff_proc": (eff_display * 255).astype(np.uint8),
        "gradcam_proc": (cam_on_proc * 255).astype(np.uint8),
        "gradcam_orig": (cam_on_orig * 255).astype(np.uint8),
    }
def compute_gradcam(eff_model, tensor_eff, device):
    """Compute a Grad-CAM heatmap for a single EfficientNet input.

    Hooks the last block of ``eff_model.backbone.features``, runs one
    forward/backward pass, and weights the captured activations by the
    spatial mean of their gradients.

    Args:
        tensor_eff: un-batched input tensor (C, H, W); a batch dim is added.
        device: torch device the model lives on.

    Returns:
        (cam, prob): ``cam`` is a 2-D numpy map normalized to [0, 1];
        ``prob`` is the sigmoid of the model output.
        NOTE(review): assumes the model emits a single logit — confirm.
    """
    gradients = {}
    activations = {}

    def forward_hook(module, inputs, output):
        activations["feat"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        gradients["feat"] = grad_output[0].detach()

    target_layer = eff_model.backbone.features[-1]
    fh = target_layer.register_forward_hook(forward_hook)
    bh = target_layer.register_full_backward_hook(backward_hook)
    try:
        input_tensor = tensor_eff.unsqueeze(0).to(device).requires_grad_(True)
        eff_model.eval()
        # enable_grad: Grad-CAM needs a graph even when the caller runs
        # inference under torch.no_grad().
        with torch.enable_grad():
            output = eff_model(input_tensor)
            prob = torch.sigmoid(output).item()
            eff_model.zero_grad()
            output.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass raises;
        # otherwise they accumulate on the model across calls.
        fh.remove()
        bh.remove()

    # Channel weights = global-average-pooled gradients (standard Grad-CAM).
    weights = gradients["feat"].mean(dim=[2, 3], keepdim=True)
    cam = (weights * activations["feat"]).sum(dim=1, keepdim=True)
    cam = torch.relu(cam).squeeze().cpu().numpy()
    # Min-max normalize to [0, 1]; guard against an all-zero map.
    cam = cam - cam.min()
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam, prob
def overlay_gradcam(original_img_np: np.ndarray, cam: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    """Blend a Grad-CAM activation map onto an image as a JET heatmap.

    Args:
        original_img_np: HxWx3 float RGB image, expected in [0, 1].
        cam: 2-D map in [0, 1]; resized to the image's spatial size.
        alpha: weight of the heatmap in the blend.

    Returns:
        HxWx3 float RGB image clipped to [0, 1].
    """
    height, width = original_img_np.shape[:2]
    cam_resized = cv2.resize(cam, (width, height))
    # applyColorMap returns BGR; convert before mixing with the RGB image.
    heat = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB) / 255.0
    blended = original_img_np * (1 - alpha) + heat * alpha
    return np.clip(blended, 0, 1)