File size: 5,933 Bytes
from __future__ import annotations
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from torch import nn
from .augmentations import IMAGENET_MEAN, IMAGENET_STD, build_eval_transform
from .compare_models import load_best_model_record
from .dl_models import create_model, load_torch_checkpoint
from .paths import ensure_dir
from .preprocessing import load_pil_image
from .utils import get_logger
# Module-level logger, created via the project's get_logger helper with this module's __name__.
LOGGER = get_logger(__name__)
def find_last_conv_layer(model: nn.Module) -> nn.Module | None:
    """Return the final ``nn.Conv2d`` yielded by ``model.modules()``.

    "Final" refers to the traversal order of ``model.modules()``. When the
    model contains no ``nn.Conv2d`` at all, ``None`` is returned.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None
def denormalize_tensor(image_tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet mean/std normalization and return an HWC ``uint8`` array.

    The tensor is detached and moved to the CPU, scaled by ``IMAGENET_STD`` and
    shifted by ``IMAGENET_MEAN`` (each reshaped to ``(3, 1, 1)``, so a
    3-channel, channel-first tensor is expected), clamped to ``[0, 1]``,
    permuted to channel-last, and finally scaled to ``0..255`` with a
    truncating cast.
    """
    channel_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    channel_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)

    restored = image_tensor.detach().cpu() * channel_std + channel_mean
    channel_last = restored.clamp(0, 1).permute(1, 2, 0)
    pixels = channel_last.numpy()
    return (pixels * 255).astype(np.uint8)
def gradcam_overlay(
    model: nn.Module,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
    target_class: int | None = None,
    device: torch.device | None = None,
) -> Image.Image:
    """Render a Grad-CAM heatmap blended over the preprocessed input image.

    The activations and gradients of the last ``nn.Conv2d`` found by
    ``find_last_conv_layer`` are captured through a temporary forward hook.
    The channel-wise mean gradient weights the activations, and the resulting
    map is passed through ReLU, min-max normalized, resized to the input
    resolution, colorized with the JET colormap and alpha-blended
    (0.55 image / 0.45 heatmap) over the de-normalized input.

    Autograd is enabled explicitly for the forward/backward pass, so the
    function also works when the caller is inside a ``torch.no_grad()`` block.

    Args:
        model: Network to explain. It is moved to ``device`` and switched to
            eval mode.
        image: Image path or PIL image; loaded via ``load_pil_image`` in RGB.
        config: Configuration handed to ``build_eval_transform``.
        output_path: When truthy, the overlay is saved there (parent
            directories are created as needed).
        target_class: Index of the logit to back-propagate. Defaults to the
            arg-max of the model's logits.
        device: Device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        The overlay as a PIL image.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d`` layer.
        RuntimeError: If the hook captured no activations or gradients.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    target_layer = find_last_conv_layer(model)
    if target_layer is None:
        raise ValueError("Grad-CAM is only available for CNN models with Conv2d layers.")

    activation: dict[str, torch.Tensor] = {}

    def forward_hook(_: nn.Module, __: tuple[torch.Tensor, ...], output: torch.Tensor) -> torch.Tensor:
        # Replace the layer output with a clone that retains its gradient, so
        # the rest of the network consumes the very tensor we read ``.grad`` from.
        cloned = output.clone()
        cloned.retain_grad()
        activation["value"] = cloned
        return cloned

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    try:
        pil = load_pil_image(image, mode="RGB")
        transform = build_eval_transform(config)
        tensor = transform(pil).unsqueeze(0).to(device)

        # Grad-CAM needs gradients; force-enable autograd in case the caller
        # wrapped this call in ``torch.no_grad()``, where ``retain_grad()`` in
        # the hook would otherwise fail.
        with torch.enable_grad():
            # Requiring grad on the input guarantees a graph through the target
            # layer even if all model parameters are frozen.
            tensor.requires_grad_(True)
            logits = model(tensor)
            class_idx = int(target_class if target_class is not None else torch.argmax(logits, dim=1).item())
            model.zero_grad(set_to_none=True)
            logits[:, class_idx].sum().backward()

        acts_with_grad = activation.get("value")
        if acts_with_grad is None or acts_with_grad.grad is None:
            raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")

        # Drop the batch dimension (batch size is 1 by construction above).
        acts = acts_with_grad.detach()[0]
        grads = acts_with_grad.grad.detach()[0]

        # Channel weights = spatially averaged gradients; weighted sum over channels.
        weights = grads.mean(dim=(1, 2), keepdim=True)
        cam = (weights * acts).sum(dim=0)
        cam = torch.relu(cam)
        cam -= cam.min()
        # Clamp the divisor so an all-zero map does not divide by zero.
        cam /= cam.max().clamp(min=1e-8)
        cam_np = cam.detach().cpu().numpy()

        # NOTE(review): presumes the eval transform applies ImageNet
        # normalization, which denormalize_tensor inverts -- confirm.
        base = denormalize_tensor(tensor[0])
        heatmap = cv2.resize(cam_np, (base.shape[1], base.shape[0]))
        heatmap_uint8 = np.uint8(255 * heatmap)
        color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        # OpenCV colormaps are BGR; the base image and PIL expect RGB.
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        overlay = np.uint8(0.55 * base + 0.45 * color)
        out = Image.fromarray(overlay)

        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            out.save(output_path)
        return out
    finally:
        # Always detach the hook so the model is left unmodified for later use.
        handle_fwd.remove()
def save_gradcam_examples_for_best(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame | None = None,
) -> list[Path]:
    """Write Grad-CAM overlays for test images using the best model.

    The best-model record is the first row of ``leaderboard_df`` when one is
    given and non-empty; otherwise it comes from ``load_best_model_record``.
    Nothing is generated (an empty list is returned) when the record is not a
    ``"deep_learning"`` model, when the checkpoint family is not ``"cnn"``, or
    when ``splits_df`` has no rows with ``split == "test"``.

    Up to ``config["explainability"]["max_images"]`` images (default 8) are
    taken from the head of each ``label`` group. Overlays are written beneath
    ``<output_dir>/plots/gradcam``. A failure on one image is logged as a
    warning and does not stop the remaining images.

    Returns:
        Paths of the overlays that were written successfully.
    """
    if leaderboard_df is not None and not leaderboard_df.empty:
        record = leaderboard_df.iloc[0].to_dict()
    else:
        record = load_best_model_record(config)

    if record.get("model_type") != "deep_learning":
        LOGGER.info("Best model is not a deep CNN; Grad-CAM generation skipped.")
        return []

    checkpoint = load_torch_checkpoint(Path(record["model_path"]), map_location="cpu")
    family = checkpoint.get("family", "cnn")
    if family != "cnn":
        LOGGER.info("Best deep model is %s; Grad-CAM is skipped for non-CNN models.", family)
        return []

    # Prefer the configuration stored with the checkpoint over the caller's.
    model_config = checkpoint.get("config", config)
    network = create_model(checkpoint["model_key"], model_config, pretrained=False)
    network.load_state_dict(checkpoint["state_dict"])

    is_test_row = splits_df["split"] == "test"
    test_df = splits_df[is_test_row].copy().reset_index(drop=True)
    if test_df.empty:
        return []

    count = int(config.get("explainability", {}).get("max_images", 8))
    per_label = max(1, count // 2)
    sample = test_df.groupby("label", group_keys=False).head(per_label).head(count)

    out_dir = ensure_dir(Path(config["paths"]["output_dir"]) / "plots" / "gradcam")
    saved: list[Path] = []
    for number, row in enumerate(sample.itertuples(index=False), start=1):
        output_path = out_dir / f"{record['model_name']}_gradcam_{number:02d}.png"
        try:
            gradcam_overlay(network, row.filepath, model_config, output_path, target_class=None)
        except Exception as exc:
            LOGGER.warning("Failed Grad-CAM for %s: %s", row.filepath, exc)
        else:
            saved.append(output_path)
    return saved
def gradcam_for_checkpoint(
    model_path: str | Path,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
) -> Image.Image:
    """Load a CNN checkpoint from ``model_path`` and return its Grad-CAM overlay.

    The checkpoint is loaded onto the CPU. The configuration stored in the
    checkpoint is used when present, otherwise ``config`` is used. The model
    is rebuilt without pretrained weights before its state dict is restored.

    Raises:
        ValueError: If the checkpoint's ``family`` entry is not ``"cnn"``
            (a missing entry is treated as ``"cnn"``).
    """
    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")

    family = checkpoint.get("family", "cnn")
    if family != "cnn":
        raise ValueError("Grad-CAM is only enabled for CNN deep-learning checkpoints.")

    effective_config = checkpoint.get("config", config)
    network = create_model(checkpoint["model_key"], effective_config, pretrained=False)
    network.load_state_dict(checkpoint["state_dict"])

    return gradcam_overlay(network, image, effective_config, output_path=output_path)
|