| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import cv2 |
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image, ImageOps |
| from torch import nn |
|
|
| from .augmentations import IMAGENET_MEAN, IMAGENET_STD, build_eval_transform |
| from .compare_models import load_best_model_record |
| from .dl_models import create_model, load_torch_checkpoint |
| from .paths import ensure_dir |
| from .preprocessing import load_pil_image |
| from .utils import get_logger |
|
|
|
|
# Module-level logger for this file, obtained from the project's ``get_logger`` helper.
LOGGER = get_logger(__name__)
|
|
|
|
def find_last_conv_layer(model: nn.Module) -> nn.Module | None:
    """Return the last ``nn.Conv2d`` reached while walking ``model.modules()``.

    ``model.modules()`` yields the model itself followed by its submodules, so
    a bare ``nn.Conv2d`` passed in is returned as-is. Returns ``None`` when the
    model contains no ``nn.Conv2d`` at all (other conv types such as
    ``nn.Conv1d`` are not matched).
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None
|
|
|
|
def denormalize_tensor(image_tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet mean/std normalisation and return an HWC ``uint8`` array.

    The tensor is detached and moved to CPU first. It is broadcast against
    3-channel mean/std tensors (``view(3, 1, 1)``), so a channel-first
    ``(3, H, W)`` layout is expected. Values are clamped to ``[0, 1]`` before
    being scaled to ``0..255``. The final cast truncates rather than rounds.
    """
    chw = image_tensor.detach().cpu()
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    restored = (chw * std + mean).clamp(0, 1)
    # Channel-first -> channel-last, the layout PIL / OpenCV expect.
    hwc = restored.permute(1, 2, 0).numpy()
    return (hwc * 255).astype(np.uint8)
|
|
|
|
def _gradcam_map(acts: torch.Tensor, grads: torch.Tensor) -> np.ndarray:
    """Combine one sample's ``(C, H, W)`` activations and gradients into an ``(H, W)`` map in ``[0, 1]``."""
    # One weight per channel: the spatially averaged gradient.
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=0))
    cam = cam - cam.min()
    # The clamp avoids division by zero when the map is constant (e.g. all zeros).
    cam = cam / cam.max().clamp(min=1e-8)
    return cam.detach().cpu().numpy()


def _blend_heatmap(base: np.ndarray, cam: np.ndarray) -> np.ndarray:
    """Colour ``cam`` with the JET colormap and blend it onto ``base`` (an HWC uint8 RGB image)."""
    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(cam, (base.shape[1], base.shape[0]))
    color = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # OpenCV colormaps are produced in BGR order; convert to match the RGB base.
    color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
    return np.uint8(0.55 * base + 0.45 * color)


def gradcam_overlay(
    model: nn.Module,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
    target_class: int | None = None,
    device: torch.device | None = None,
) -> Image.Image:
    """Render a Grad-CAM heatmap for ``image`` blended over the model input.

    The target layer is the last ``nn.Conv2d`` found by ``find_last_conv_layer``.
    Gradients are computed with ``torch.autograd.grad`` with respect to that
    layer's activation only. The model parameters' ``.grad`` buffers are
    therefore neither cleared nor populated by this call.

    Args:
        model: Network to explain. It is moved to ``device`` and switched to
            eval mode in place.
        image: File path or PIL image, loaded as RGB via ``load_pil_image``.
        config: Configuration passed to ``build_eval_transform`` to build the
            input tensor.
        output_path: When truthy, the overlay is also saved there (parent
            directories are created).
        target_class: Class index to explain; defaults to the arg-max logit.
        device: Device to run on; defaults to CUDA when available, else CPU.

    Returns:
        The overlay as an RGB ``PIL.Image`` at the transformed input's size.

    Raises:
        ValueError: If ``model`` contains no ``nn.Conv2d`` layer.
        RuntimeError: If the hook did not capture an activation or no gradient
            reached it.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    target_layer = find_last_conv_layer(model)
    if target_layer is None:
        raise ValueError("Grad-CAM is only available for CNN models with Conv2d layers.")

    activation: dict[str, torch.Tensor] = {}

    def forward_hook(_: nn.Module, __: tuple[torch.Tensor, ...], output: torch.Tensor) -> torch.Tensor:
        # Returning the clone substitutes it for the layer output, so the
        # captured tensor is exactly the one the rest of the network consumes.
        cloned = output.clone()
        activation["value"] = cloned
        return cloned

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    try:
        pil = load_pil_image(image, mode="RGB")
        transform = build_eval_transform(config)
        tensor = transform(pil).unsqueeze(0).to(device)
        # enable_grad lets this work even when the caller is inside torch.no_grad().
        # requires_grad on the input keeps a graph even if all parameters are frozen.
        with torch.enable_grad():
            tensor.requires_grad_(True)
            logits = model(tensor)
            class_idx = int(target_class if target_class is not None else torch.argmax(logits, dim=1).item())
            acts = activation.get("value")
            if acts is None or not acts.requires_grad:
                raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")
            # Differentiate w.r.t. the activation only. Unlike backward(), this
            # does not wipe or fill the caller's parameter gradients.
            (grads,) = torch.autograd.grad(logits[:, class_idx].sum(), acts, allow_unused=True)
        if grads is None:
            raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")
        cam_np = _gradcam_map(acts.detach()[0], grads.detach()[0])
        overlay = _blend_heatmap(denormalize_tensor(tensor[0]), cam_np)
        out = Image.fromarray(overlay)
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            out.save(output_path)
        return out
    finally:
        # Always detach the hook so the model is left unmodified on error paths too.
        handle_fwd.remove()
|
|
|
|
def save_gradcam_examples_for_best(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame | None = None,
) -> list[Path]:
    """Write Grad-CAM overlays for a sample of test images using the best model.

    Selecting the best model:
      * When ``leaderboard_df`` is given and non-empty, its first row is treated
        as the best model.
      * Otherwise the record comes from ``load_best_model_record(config)``.

    Nothing is generated, and an empty list is returned, when any of these hold:
      * the record's ``model_type`` is not ``"deep_learning"``;
      * the checkpoint ``family`` (default ``"cnn"``) is not ``"cnn"``;
      * ``splits_df`` has no ``"test"`` rows.

    Args:
        config: Run configuration. Reads ``paths.output_dir`` and the optional
            ``explainability.max_images`` (default 8).
        splits_df: Table with at least ``split``, ``label`` and ``filepath``
            columns.
        leaderboard_df: Optional leaderboard; only its first row is used.

    Returns:
        Paths of the overlays written successfully. Images whose Grad-CAM
        fails are logged as warnings and skipped.
    """
    if leaderboard_df is not None and not leaderboard_df.empty:
        record = leaderboard_df.iloc[0].to_dict()
    else:
        record = load_best_model_record(config)
    if record.get("model_type") != "deep_learning":
        LOGGER.info("Best model is not a deep CNN; Grad-CAM generation skipped.")
        return []
    model_path = Path(record["model_path"])
    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")
    family = checkpoint.get("family", "cnn")
    if family != "cnn":
        LOGGER.info("Best deep model is %s; Grad-CAM is skipped for non-CNN models.", family)
        return []
    # Prefer the config saved in the checkpoint so the architecture and the eval
    # transform match what the model was trained with; resolve it once.
    model_config = checkpoint.get("config", config)
    model = create_model(checkpoint["model_key"], model_config, pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    test_df = splits_df[splits_df["split"] == "test"].copy().reset_index(drop=True)
    if test_df.empty:
        return []
    # The section may be present but set to None; treat that the same as absent.
    explain_cfg = config.get("explainability") or {}
    count = int(explain_cfg.get("max_images", 8))
    # Up to count // 2 (at least 1) rows per label, in original order, capped at count overall.
    sample = test_df.groupby("label", group_keys=False).head(max(1, count // 2)).head(count)
    out_dir = ensure_dir(Path(config["paths"]["output_dir"]) / "plots" / "gradcam")
    saved: list[Path] = []
    for idx, row in enumerate(sample.itertuples(index=False), start=1):
        output_path = out_dir / f"{record['model_name']}_gradcam_{idx:02d}.png"
        try:
            gradcam_overlay(model, row.filepath, model_config, output_path, target_class=None)
        except Exception as exc:  # best-effort: one bad image must not abort the batch
            LOGGER.warning("Failed Grad-CAM for %s: %s", row.filepath, exc)
            continue
        saved.append(output_path)
    return saved
|
|
|
|
def gradcam_for_checkpoint(
    model_path: str | Path,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
) -> Image.Image:
    """Load a CNN checkpoint from ``model_path`` and return its Grad-CAM overlay for ``image``.

    The checkpoint is loaded onto CPU. The config stored in the checkpoint
    (falling back to ``config``) is used both to rebuild the model and to
    preprocess the image. When ``output_path`` is given, the overlay is saved
    there as well.

    Raises:
        ValueError: If the checkpoint's ``family`` (default ``"cnn"``) is not ``"cnn"``.
    """
    ckpt = load_torch_checkpoint(model_path, map_location="cpu")
    if ckpt.get("family", "cnn") != "cnn":
        raise ValueError("Grad-CAM is only enabled for CNN deep-learning checkpoints.")
    effective_config = ckpt.get("config", config)
    network = create_model(ckpt["model_key"], effective_config, pretrained=False)
    network.load_state_dict(ckpt["state_dict"])
    return gradcam_overlay(network, image, effective_config, output_path=output_path)
|
|