File size: 5,933 Bytes
1264815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from __future__ import annotations

from pathlib import Path
from typing import Any

import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from torch import nn

from .augmentations import IMAGENET_MEAN, IMAGENET_STD, build_eval_transform
from .compare_models import load_best_model_record
from .dl_models import create_model, load_torch_checkpoint
from .paths import ensure_dir
from .preprocessing import load_pil_image
from .utils import get_logger


LOGGER = get_logger(__name__)


def find_last_conv_layer(model: nn.Module) -> nn.Module | None:
    """Return the last ``nn.Conv2d`` yielded by ``model.modules()``, or ``None``.

    The order is that of ``model.modules()`` (module registration order).
    That is not necessarily the order in which ``forward`` executes the layers.

    Args:
        model: Any module tree to search, including ``model`` itself.

    Returns:
        The final ``nn.Conv2d`` encountered, or ``None`` when there is none.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None


def denormalize_tensor(image_tensor: torch.Tensor) -> np.ndarray:
    """Undo ImageNet normalisation and return an HWC ``uint8`` image array.

    Args:
        image_tensor: A 3-channel, channel-first image tensor normalised with
            ``IMAGENET_MEAN`` / ``IMAGENET_STD``. It is detached and copied to
            the CPU, so the caller's tensor is not modified.

    Returns:
        ``(H, W, 3)`` ``uint8`` array. Values are clamped to [0, 1] before scaling
        by 255 and truncated (not rounded) by the ``uint8`` cast.
    """
    pixels = image_tensor.detach().cpu()
    channel_shape = (3, 1, 1)  # broadcast per-channel stats over H and W
    mean = torch.tensor(IMAGENET_MEAN).view(*channel_shape)
    std = torch.tensor(IMAGENET_STD).view(*channel_shape)
    restored = torch.clamp(pixels * std + mean, 0, 1)
    hwc = restored.permute(1, 2, 0).numpy()
    return (hwc * 255).astype(np.uint8)


def _compute_gradcam(
    model: nn.Module,
    target_layer: nn.Module,
    tensor: torch.Tensor,
    target_class: int | None,
) -> torch.Tensor:
    """Return a 2-D Grad-CAM map normalised to [0, 1] for the first item of ``tensor``."""
    activation: dict[str, torch.Tensor] = {}

    def forward_hook(_: nn.Module, __: tuple[torch.Tensor, ...], output: torch.Tensor) -> torch.Tensor:
        # Returning a value from a forward hook replaces the layer output. The
        # clone is what downstream layers see, so its gradient is retained and
        # read after backward.
        cloned = output.clone()
        cloned.retain_grad()
        activation["value"] = cloned
        return cloned

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        # Otherwise retain_grad() raises because the layer output would not
        # require grad.
        with torch.enable_grad():
            # Input requires grad so the graph exists even when params are frozen.
            tensor.requires_grad_(True)
            logits = model(tensor)
            class_idx = int(target_class if target_class is not None else torch.argmax(logits, dim=1).item())
            model.zero_grad(set_to_none=True)
            logits[:, class_idx].sum().backward()
    finally:
        handle.remove()
        # Do not leave explanation-time gradients attached to the parameters.
        model.zero_grad(set_to_none=True)

    acts_with_grad = activation.get("value")
    if acts_with_grad is None or acts_with_grad.grad is None:
        raise RuntimeError("Grad-CAM hook did not capture activations/gradients.")
    acts = acts_with_grad.detach()[0]
    grads = acts_with_grad.grad.detach()[0]
    # Channel weights = spatially averaged gradients (global average pooling).
    weights = grads.mean(dim=(1, 2), keepdim=True)
    cam = torch.relu((weights * acts).sum(dim=0))
    cam = cam - cam.min()
    return cam / cam.max().clamp(min=1e-8)  # clamp avoids 0/0 on a flat map


def _render_cam_overlay(cam: torch.Tensor, input_tensor: torch.Tensor) -> Image.Image:
    """Blend ``cam`` as a JET heatmap over the de-normalised ``input_tensor``."""
    # float32: cv2.resize does not accept float16 (e.g. from half-precision models).
    cam_np = cam.detach().float().cpu().numpy()
    base = denormalize_tensor(input_tensor)
    heatmap = cv2.resize(cam_np, (base.shape[1], base.shape[0]))  # dsize is (width, height)
    color = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)  # OpenCV colour maps are BGR
    overlay = np.uint8(0.55 * base + 0.45 * color)
    return Image.fromarray(overlay)


def gradcam_overlay(
    model: nn.Module,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
    target_class: int | None = None,
    device: torch.device | None = None,
) -> Image.Image:
    """Render a Grad-CAM heatmap for ``image`` blended over the preprocessed input.

    The activations and gradients of the last ``nn.Conv2d`` returned by
    ``find_last_conv_layer`` are used to build the class activation map.

    Args:
        model: CNN classifier. It is moved to ``device`` and switched to eval
            mode in place. Its parameter gradients are cleared.
        image: Image path or already-loaded PIL image; it is loaded as RGB.
        config: Config passed to ``build_eval_transform`` for preprocessing.
        output_path: If truthy, the overlay is also saved there, and missing
            parent directories are created.
        target_class: Class index to explain; defaults to the arg-max logit.
        device: Device to run on; defaults to CUDA when available, else CPU.

    Returns:
        RGB PIL image at the spatial size of the transformed input tensor.

    Raises:
        ValueError: If the model contains no ``nn.Conv2d`` layer.
        RuntimeError: If the hook did not capture activations or gradients.
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    target_layer = find_last_conv_layer(model)
    if target_layer is None:
        raise ValueError("Grad-CAM is only available for CNN models with Conv2d layers.")

    pil = load_pil_image(image, mode="RGB")
    transform = build_eval_transform(config)
    tensor = transform(pil).unsqueeze(0).to(device)
    cam = _compute_gradcam(model, target_layer, tensor, target_class)
    out = _render_cam_overlay(cam, tensor[0])
    if output_path:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        out.save(output_path)
    return out


def _balanced_gradcam_sample(test_df: pd.DataFrame, count: int) -> pd.DataFrame:
    """Select up to ``count`` rows of ``test_df`` with a near-equal quota per label."""
    if count <= 0:
        return test_df.iloc[0:0]
    n_labels = max(1, test_df["label"].nunique())
    # Ceil-divide so odd counts and single-label sets still reach ``count``.
    # This equals ``count // 2`` for the common binary, even-count case.
    per_label = (count + n_labels - 1) // n_labels
    return test_df.groupby("label", group_keys=False).head(per_label).head(count)


def save_gradcam_examples_for_best(
    config: dict[str, Any],
    splits_df: pd.DataFrame,
    leaderboard_df: pd.DataFrame | None = None,
) -> list[Path]:
    """Save Grad-CAM overlays for a label-balanced sample of test-split images.

    The best model is the first row of ``leaderboard_df`` when it is given and
    non-empty; otherwise ``load_best_model_record(config)`` supplies it.
    Generation is skipped (returning ``[]``) when the best model is not a
    ``"deep_learning"`` model or its checkpoint family is not ``"cnn"``.

    Args:
        config: Project config. ``explainability.max_images`` caps the number
            of images (default 8), and ``paths.output_dir`` sets the output root.
        splits_df: Table with ``split``, ``label`` and ``filepath`` columns.
        leaderboard_df: Optional ranked leaderboard whose first row is the best.

    Returns:
        Paths of overlays written under ``<output_dir>/plots/gradcam``. Images
        whose Grad-CAM fails are logged and omitted (best-effort).
    """
    record = (
        leaderboard_df.iloc[0].to_dict()
        if leaderboard_df is not None and not leaderboard_df.empty
        else load_best_model_record(config)
    )
    if record.get("model_type") != "deep_learning":
        LOGGER.info("Best model is not a deep CNN; Grad-CAM generation skipped.")
        return []
    model_path = Path(record["model_path"])
    checkpoint = load_torch_checkpoint(model_path, map_location="cpu")
    family = checkpoint.get("family", "cnn")
    if family != "cnn":
        LOGGER.info("Best deep model is %s; Grad-CAM is skipped for non-CNN models.", family)
        return []
    # Prefer the config saved with the checkpoint so preprocessing matches training.
    model_config = checkpoint.get("config", config)
    model = create_model(checkpoint["model_key"], model_config, pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    test_df = splits_df[splits_df["split"] == "test"].reset_index(drop=True)
    if test_df.empty:
        return []
    count = int(config.get("explainability", {}).get("max_images", 8))
    sample = _balanced_gradcam_sample(test_df, count)
    out_dir = ensure_dir(Path(config["paths"]["output_dir"]) / "plots" / "gradcam")
    saved: list[Path] = []
    for idx, row in enumerate(sample.itertuples(index=False), start=1):
        output_path = out_dir / f"{record['model_name']}_gradcam_{idx:02d}.png"
        try:
            gradcam_overlay(model, row.filepath, model_config, output_path, target_class=None)
            saved.append(output_path)
        except Exception as exc:  # best-effort: one bad image must not abort the batch
            LOGGER.warning("Failed Grad-CAM for %s: %s", row.filepath, exc)
    return saved


def gradcam_for_checkpoint(
    model_path: str | Path,
    image: str | Path | Image.Image,
    config: dict[str, Any],
    output_path: str | Path | None = None,
) -> Image.Image:
    """Rebuild a CNN from a saved checkpoint and return its Grad-CAM overlay.

    The checkpoint's stored ``config`` (falling back to ``config``) is used both
    to create the model and to preprocess ``image``.

    Raises:
        ValueError: If the checkpoint's ``family`` is not ``"cnn"``.
    """
    ckpt = load_torch_checkpoint(model_path, map_location="cpu")
    is_cnn = ckpt.get("family", "cnn") == "cnn"
    if not is_cnn:
        raise ValueError("Grad-CAM is only enabled for CNN deep-learning checkpoints.")
    effective_config = ckpt.get("config", config)
    network = create_model(ckpt["model_key"], effective_config, pretrained=False)
    network.load_state_dict(ckpt["state_dict"])
    return gradcam_overlay(network, image, effective_config, output_path=output_path)