from __future__ import annotations

from pathlib import Path
from typing import Any

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from torch import nn

from .augmentations import IMAGENET_MEAN, IMAGENET_STD, build_eval_transform
from .compare_models import load_best_model_record
from .dl_models import create_model, load_torch_checkpoint
from .paths import ensure_dir
from .preprocessing import load_pil_image
from .utils import get_logger

LOGGER = get_logger(__name__)


def find_last_conv_layer(model: nn.Module) -> nn.Module | None:
    """Return the last ``nn.Conv2d`` yielded by ``model.modules()``, or ``None``.

    ``modules()`` iterates in registration order, which usually (but is not
    guaranteed to) match the order layers execute in ``forward``.
    """
    last_conv: nn.Module | None = None
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            last_conv = module
    return last_conv


def denormalize_tensor(image_tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet mean/std normalization and return an HWC ``uint8`` array.

    Args:
        image_tensor: A channels-first (3, H, W) normalized image tensor.

    Returns:
        An (H, W, 3) ``uint8`` array; values are clamped to [0, 1] before being
        scaled to [0, 255] (truncated, not rounded).
    """
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    image = image_tensor.detach().cpu() * std + mean
    image = image.clamp(0, 1).permute(1, 2, 0).numpy()
    return (image * 255).astype(np.uint8)


def gradcam_overlay(
    model: nn.Module,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
    target_class: int | None = None,
    device: torch.device | None = None,
) -> Image.Image:
    """Compute a Grad-CAM heatmap on the last Conv2d layer and blend it over the image.

    Side effects: ``model`` is moved to ``device`` and switched to eval mode.
    Parameter ``.grad`` attributes are not touched (gradients are taken only
    with respect to the hooked activation).

    Args:
        model: A CNN containing at least one ``nn.Conv2d``.
        image: Path or PIL image; loaded as RGB and passed through the eval transform.
        config: Config forwarded to ``build_eval_transform``.
        output_path: If truthy, the overlay is saved there (parent dirs created).
        target_class: Class index to explain; defaults to the arg-max prediction.
        device: Device to run on; defaults to CUDA when available, else CPU.

    Returns:
        A PIL image: 55% de-normalized input + 45% JET-colored heatmap.

    Raises:
        ValueError: If the model has no ``nn.Conv2d`` layer.
        RuntimeError: If the hook did not capture a differentiable activation.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    target_layer = find_last_conv_layer(model)
    if target_layer is None:
        raise ValueError("Grad-CAM is only available for CNN models with Conv2d layers.")

    activation: dict[str, torch.Tensor] = {}

    def forward_hook(_: nn.Module, __: tuple[torch.Tensor, ...], output: torch.Tensor) -> torch.Tensor:
        # Returning a clone makes the captured tensor the node the rest of the
        # network consumes, so gradients w.r.t. it flow through the real graph.
        cloned = output.clone()
        activation["value"] = cloned
        return cloned

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    try:
        pil = load_pil_image(image, mode="RGB")
        transform = build_eval_transform(config)
        tensor = transform(pil).unsqueeze(0).to(device)

        # Force grad mode so Grad-CAM also works when called inside torch.no_grad().
        with torch.enable_grad():
            # Ensures activations require grad even if model parameters are frozen.
            tensor.requires_grad_(True)
            logits = model(tensor)
            class_idx = int(
                target_class if target_class is not None else torch.argmax(logits, dim=1).item()
            )
            acts_with_grad = activation.get("value")
            if acts_with_grad is None or not acts_with_grad.requires_grad:
                raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")
            # autograd.grad returns only d(score)/d(acts) and leaves parameter .grad untouched.
            (acts_grad,) = torch.autograd.grad(
                logits[:, class_idx].sum(), acts_with_grad, allow_unused=True
            )
        if acts_grad is None:
            raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")

        acts = acts_with_grad.detach()[0]
        grads = acts_grad.detach()[0]
        # Channel weights = spatially averaged gradients (standard Grad-CAM).
        weights = grads.mean(dim=(1, 2), keepdim=True)
        cam = (weights * acts).sum(dim=0)
        cam = torch.relu(cam)
        cam -= cam.min()
        cam /= cam.max().clamp(min=1e-8)  # avoid division by zero on an all-zero map
        cam_np = cam.detach().cpu().numpy()

        base = denormalize_tensor(tensor[0])
        heatmap = cv2.resize(cam_np, (base.shape[1], base.shape[0]))
        heatmap_uint8 = np.uint8(255 * heatmap)
        color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
        overlay = np.uint8(0.55 * base + 0.45 * color)
        out = Image.fromarray(overlay)
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            out.save(output_path)
        return out
    finally:
        handle_fwd.remove()


def _load_cnn_from_checkpoint(
    checkpoint: dict[str, Any], config: dict[str, Any]
) -> tuple[nn.Module, dict[str, Any]]:
    """Rebuild the model stored in ``checkpoint`` and return it with its effective config.

    The checkpoint's own ``"config"`` entry takes precedence over ``config``.
    """
    model_config = checkpoint.get("config", config)
    model = create_model(checkpoint["model_key"], model_config, pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    return model, model_config


def save_gradcam_examples_for_best(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame | None = None,
) -> list[Path]:
    """Write Grad-CAM overlays for a sample of test images using the best model.

    The best model is the first row of ``leaderboard_df`` when it is provided and
    non-empty (NOTE(review): assumes the leaderboard is already sorted best-first),
    otherwise ``load_best_model_record(config)``. Generation is skipped (returns
    ``[]``) when that model is not ``"deep_learning"`` or its checkpoint family is
    not ``"cnn"``, or when there are no ``split == "test"`` rows.

    Up to ``config["explainability"]["max_images"]`` (default 8) test rows are
    used, taking at most ``max(1, count // 2)`` rows per label. PNGs are written
    to ``<paths.output_dir>/plots/gradcam``. Per-image failures are logged and
    skipped.

    Returns:
        Paths of the overlays that were written successfully.
    """
    record = (
        leaderboard_df.iloc[0].to_dict()
        if leaderboard_df is not None and not leaderboard_df.empty
        else load_best_model_record(config)
    )
    if record.get("model_type") != "deep_learning":
        LOGGER.info("Best model is not a deep CNN; Grad-CAM generation skipped.")
        return []

    model_path = Path(record["model_path"])
    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")
    family = checkpoint.get("family", "cnn")
    if family != "cnn":
        LOGGER.info("Best deep model is %s; Grad-CAM is skipped for non-CNN models.", family)
        return []

    model, model_config = _load_cnn_from_checkpoint(checkpoint, config)

    test_df = splits_df[splits_df["split"] == "test"].copy().reset_index(drop=True)
    if test_df.empty:
        return []

    count = int(config.get("explainability", {}).get("max_images", 8))
    sample = test_df.groupby("label", group_keys=False).head(max(1, count // 2)).head(count)
    out_dir = ensure_dir(Path(config["paths"]["output_dir"]) / "plots" / "gradcam")
    # Same default device selection as gradcam_overlay, decided once for the loop.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    saved: list[Path] = []
    for idx, row in enumerate(sample.itertuples(index=False), start=1):
        output_path = out_dir / f"{record['model_name']}_gradcam_{idx:02d}.png"
        try:
            gradcam_overlay(
                model, row.filepath, model_config, output_path, target_class=None, device=device
            )
            saved.append(output_path)
        except Exception as exc:  # best-effort: one bad image must not abort the batch
            LOGGER.warning("Failed Grad-CAM for %s: %s", row.filepath, exc)
    return saved


def gradcam_for_checkpoint(
    model_path: str | Path,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
) -> Image.Image:
    """Load a CNN checkpoint and return its Grad-CAM overlay for ``image``.

    Raises:
        ValueError: If the checkpoint's ``"family"`` is not ``"cnn"``.
    """
    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")
    if checkpoint.get("family", "cnn") != "cnn":
        raise ValueError("Grad-CAM is only enabled for CNN deep-learning checkpoints.")
    model, model_config = _load_cnn_from_checkpoint(checkpoint, config)
    return gradcam_overlay(model, image, model_config, output_path=output_path)