File size: 5,738 Bytes
c1969c3
 
 
 
 
 
 
 
 
 
0e8457d
 
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91347e8
 
 
 
 
 
 
 
 
c1969c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91347e8
 
 
c1969c3
 
0e8457d
c1969c3
 
0e8457d
c1969c3
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Visualization utilities: Grad-CAM, ROC curves, training plots."""

from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as functional
from sklearn.metrics import auc, roc_curve

from src.data.dataset import PATHOLOGY_LABELS


def plot_roc_curves(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    labels: list[str] = PATHOLOGY_LABELS,
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot one ROC curve per pathology class on a shared axis.

    Classes whose ground-truth column does not contain both a positive and a
    negative sample are skipped, because a ROC curve (and its AUC) is
    undefined for them.

    Args:
        y_true: Ground-truth indicator array, read column-wise as
            ``y_true[:, i]`` for the class at position ``i`` of ``labels``.
        y_prob: Predicted scores, read with the same ``[:, i]`` layout.
        labels: Class names; entry ``i`` names column ``i``.
        save_path: If given, the figure is also written to this path.

    Returns:
        The Matplotlib figure holding the curves and the chance diagonal.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    for i, label in enumerate(labels):
        truth = y_true[:, i]
        # The original gate (at least one positive) is kept; additionally
        # require a second distinct value so all-positive columns do not
        # produce an undefined FPR and an "AUC=nan" legend entry.
        if not (truth.sum() > 0 and np.unique(truth).size > 1):
            continue
        fpr, tpr, _ = roc_curve(truth, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{label} (AUC={roc_auc:.3f})")

    # Chance-level reference line (unlabelled, so it never enters the legend).
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curves — Multi-Label Classification")

    # Calling legend() without labelled artists warns and draws an empty box.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
    fig.tight_layout()

    if save_path:
        # bbox_inches="tight" keeps the legend, which sits outside the axes.
        fig.savefig(save_path, dpi=150, bbox_inches="tight")

    return fig


def plot_confusion_matrix(
    cm: np.ndarray,
    labels: list[str] | None = None,
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        cm: Confusion matrix to draw. Integer matrices are annotated as whole
            numbers; any other dtype (e.g. a row-normalised float matrix) is
            annotated with two decimals.
        labels: Tick labels used for both axes. Defaults to
            ``["Normal", "Abnormal"]``.
        save_path: If given, the figure is also written to this path.

    Returns:
        The Matplotlib figure containing the heatmap.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if labels is None:
        labels = ["Normal", "Abnormal"]

    # The "d" format code raises ValueError for float cells, so pick the
    # annotation format from the matrix dtype instead of hard-coding it.
    cell_format = "d" if np.issubdtype(np.asarray(cm).dtype, np.integer) else ".2f"

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt=cell_format,
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix — Binary Classification")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150)

    return fig


def plot_training_history(
    history: dict[str, list[float]],
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot training and validation loss/metric curves in three panels.

    Panels, left to right: loss, multi-label AUC-ROC, binary metrics. Each
    series is drawn only if its key is present in ``history``; a panel with
    no series stays empty and gets no legend.

    Args:
        history: Mapping from metric name to its per-epoch values. Keys read:
            ``train_combined_loss``, ``val_multilabel_loss``,
            ``auc_roc_macro``, ``binary_auc_roc``, ``binary_f1``.
        save_path: If given, the figure is also written to this path.

    Returns:
        The Matplotlib figure containing the three panels.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # One entry per panel: (title, ((history key, legend label, color), ...)).
    # A color of None leaves the choice to Matplotlib's default color cycle.
    panels = (
        (
            "Loss",
            (
                ("train_combined_loss", "Train Loss", None),
                ("val_multilabel_loss", "Val ML Loss", None),
            ),
        ),
        (
            "Multi-Label AUC-ROC",
            (("auc_roc_macro", "Macro AUC-ROC", "green"),),
        ),
        (
            "Binary Classification",
            (
                ("binary_auc_roc", "Binary AUC-ROC", "orange"),
                ("binary_f1", "Binary F1", "red"),
            ),
        ),
    )

    for ax, (title, series) in zip(axes, panels):
        for key, legend_label, color in series:
            if key not in history:
                continue
            line_style = {"color": color} if color is not None else {}
            ax.plot(history[key], label=legend_label, **line_style)

        ax.set_title(title)
        ax.set_xlabel("Epoch")

        # legend() on an axis without labelled lines warns and draws an
        # empty box, so only add it when something was actually plotted.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150)

    return fig


class GradCAM:
    """Grad-CAM implementation for model interpretability.

    Generates heatmaps showing which image regions the model focuses on for its predictions.
    Useful for verifying that the model attends to clinically relevant areas.

    The instance registers a forward hook and a full backward hook on
    ``target_layer``. Call :meth:`remove_hooks` (or use the instance as a
    context manager) when finished so the hooks do not stay attached to the
    layer.
    """

    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module) -> None:
        """Attach capture hooks to ``target_layer``.

        Args:
            model: Model to explain. Its forward pass is indexed with
                ``"<task>_logits"`` keys by :meth:`generate`.
            target_layer: Layer whose output activations and output gradients
                are captured.
        """
        self.model = model
        self.target_layer = target_layer
        self.gradients: torch.Tensor | None = None
        self.activations: torch.Tensor | None = None

        # Keep the handles: without them the hooks can never be detached and
        # the layer would hold a reference to this object indefinitely.
        self._hook_handles = [
            target_layer.register_forward_hook(self._forward_hook),
            target_layer.register_full_backward_hook(self._backward_hook),
        ]

    def __enter__(self) -> "GradCAM":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.remove_hooks()

    def remove_hooks(self) -> None:
        """Detach the hooks registered on ``target_layer``; safe to call twice."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _forward_hook(self, module: torch.nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        """Store a detached copy of the target layer's output."""
        self.activations = output.detach()

    def _backward_hook(
        self,
        module: torch.nn.Module,
        grad_input: tuple[torch.Tensor, ...] | torch.Tensor,
        grad_output: tuple[torch.Tensor, ...] | torch.Tensor,
    ) -> tuple[torch.Tensor, ...] | torch.Tensor | None:
        """Store the gradient w.r.t. the target layer's output.

        Returns ``None`` so the gradients flowing through the layer are left
        unmodified.
        """
        grad = grad_output[0] if isinstance(grad_output, tuple) else grad_output
        self.gradients = grad.detach()
        return None

    def generate(self, image: torch.Tensor, class_idx: int, task: str = "multilabel") -> np.ndarray:
        """Generate Grad-CAM heatmap for a given class.

        Side effects: the model is switched to eval mode, and its parameter
        gradients are zeroed and then repopulated by the backward pass.

        Args:
            image: Input image tensor (1, C, H, W).
            class_idx: Target class index. Ignored when ``task == "binary"``.
            task: "multilabel" or "binary"; selects the ``"<task>_logits"``
                entry of the model output.

        Returns:
            Heatmap as numpy array (H, W), values in [0, 1]. A map without any
            spatial variation is returned as all zeros.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients.
        """
        self.model.eval()

        # Forget captures from an earlier call so a hook that did not fire is
        # reported below instead of silently reusing stale tensors.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even if the caller wrapped evaluation code
        # in torch.no_grad().
        with torch.enable_grad():
            output = self.model(image)

            logits_key = f"{task}_logits"
            if task == "binary":
                target = output[logits_key].squeeze()
            else:
                target = output[logits_key][0, class_idx]

            self.model.zero_grad()
            target.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Gradients/activations not captured — check that target_layer is used in the "
                "model's forward pass and that its hooks have not been removed"
            )

        # Global average pooling of gradients gives one weight per channel.
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = functional.relu(cam)

        # Resize to input image dimensions
        cam = functional.interpolate(cam, size=image.shape[2:], mode="bilinear", align_corners=False)
        # Drop only the channel and (size-1) batch axes; a bare squeeze()
        # would also drop spatial axes of size 1.
        heatmap = cam.squeeze(1).squeeze(0).cpu().numpy()

        # Normalize to [0, 1]
        low, high = heatmap.min(), heatmap.max()
        if high - low > 0:
            return (heatmap - low) / (high - low)
        # A constant map carries no localisation signal; zeros keep the
        # documented [0, 1] range even when the constant is above 1.
        return np.zeros_like(heatmap)