chexvision-demo / src /utils /visualization.py
arudaev's picture
fix(types): resolve all 8 mypy errors across 5 files
91347e8
"""Visualization utilities: Grad-CAM, ROC curves, training plots."""
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as functional
from sklearn.metrics import auc, roc_curve
from src.data.dataset import PATHOLOGY_LABELS
def plot_roc_curves(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    labels: list[str] = PATHOLOGY_LABELS,
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot one ROC curve per pathology class on a shared axis.

    Args:
        y_true: Ground-truth matrix; column ``i`` (``y_true[:, i]``) holds the
            targets for ``labels[i]``. Assumed to be binary 0/1 indicators --
            TODO confirm against the caller.
        y_prob: Predicted scores, indexed the same way as ``y_true``.
        labels: Class names, one per column.
        save_path: When truthy, the figure is also written there (150 dpi).

    Returns:
        The matplotlib figure holding the curves and the chance diagonal.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    for i, label in enumerate(labels):
        column = y_true[:, i]
        # A ROC curve needs both outcomes in the ground truth. Skip columns
        # with no positives (original guard) and columns holding a single
        # distinct value (e.g. all positives), for which the rates are undefined.
        if column.sum() <= 0 or np.unique(column).size < 2:
            continue
        fpr, tpr, _ = roc_curve(column, y_prob[:, i])
        roc_auc = auc(fpr, tpr)
        ax.plot(fpr, tpr, label=f"{label} (AUC={roc_auc:.3f})")

    # Chance-level reference line.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curves — Multi-Label Classification")

    # Only build a legend when at least one class was actually drawn;
    # an empty legend just triggers a matplotlib warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    return fig
def plot_confusion_matrix(
    cm: np.ndarray,
    labels: list[str] | None = None,
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        cm: Confusion matrix to draw. Integer matrices are annotated as whole
            numbers; any other dtype (e.g. a row-normalised matrix) is
            annotated with two decimals.
        labels: Tick labels used for both axes. Defaults to
            ``["Normal", "Abnormal"]``.
        save_path: When truthy, the figure is also written there (150 dpi).

    Returns:
        The matplotlib figure containing the heatmap.
    """
    # Resolve the default here instead of sharing one list object across calls.
    if labels is None:
        labels = ["Normal", "Abnormal"]

    # The "d" format code only accepts integers; pick a float format otherwise
    # so normalised matrices do not fail during annotation.
    is_integer = np.issubdtype(np.asarray(cm).dtype, np.integer)
    annotation_fmt = "d" if is_integer else ".2f"

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        cm,
        annot=True,
        fmt=annotation_fmt,
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix — Binary Classification")

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150)
    return fig
def plot_training_history(
    history: dict[str, list[float]],
    save_path: Path | None = None,
) -> plt.Figure:
    """Plot training and validation loss/metric curves in three side-by-side panels.

    Each panel draws only the series whose key is present in ``history``:

    * Loss: ``train_combined_loss``, ``val_multilabel_loss``
    * Multi-Label AUC-ROC: ``auc_roc_macro``
    * Binary Classification: ``binary_auc_roc``, ``binary_f1``

    Args:
        history: Mapping from metric name to its per-epoch values.
        save_path: When truthy, the figure is also written there (150 dpi).

    Returns:
        The matplotlib figure containing the three panels.
    """
    # (panel title, ((history key, plot keyword arguments), ...))
    panels = (
        (
            "Loss",
            (
                ("train_combined_loss", {"label": "Train Loss"}),
                ("val_multilabel_loss", {"label": "Val ML Loss"}),
            ),
        ),
        (
            "Multi-Label AUC-ROC",
            (("auc_roc_macro", {"label": "Macro AUC-ROC", "color": "green"}),),
        ),
        (
            "Binary Classification",
            (
                ("binary_auc_roc", {"label": "Binary AUC-ROC", "color": "orange"}),
                ("binary_f1", {"label": "Binary F1", "color": "red"}),
            ),
        ),
    )

    fig, axes = plt.subplots(1, len(panels), figsize=(18, 5))
    for ax, (title, series) in zip(axes, panels):
        for key, plot_kwargs in series:
            if key in history:
                ax.plot(history[key], **plot_kwargs)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        # Missing keys are allowed, so a panel may be empty; skip the legend
        # then instead of letting matplotlib warn about having no entries.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150)
    return fig
class GradCAM:
"""Grad-CAM implementation for model interpretability.
Generates heatmaps showing which image regions the model focuses on for its predictions.
Useful for verifying that the model attends to clinically relevant areas.
"""
def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module) -> None:
self.model = model
self.target_layer = target_layer
self.gradients: torch.Tensor | None = None
self.activations: torch.Tensor | None = None
# Register hooks
target_layer.register_forward_hook(self._forward_hook)
target_layer.register_full_backward_hook(self._backward_hook)
def _forward_hook(self, module: torch.nn.Module, input: tuple, output: torch.Tensor) -> None:
self.activations = output.detach()
def _backward_hook(
self,
module: torch.nn.Module,
grad_input: tuple[torch.Tensor, ...] | torch.Tensor,
grad_output: tuple[torch.Tensor, ...] | torch.Tensor,
) -> tuple[torch.Tensor, ...] | torch.Tensor | None:
g = grad_output[0] if isinstance(grad_output, tuple) else grad_output
self.gradients = g.detach()
return None
def generate(self, image: torch.Tensor, class_idx: int, task: str = "multilabel") -> np.ndarray:
"""Generate Grad-CAM heatmap for a given class.
Args:
image: Input image tensor (1, C, H, W).
class_idx: Target class index.
task: "multilabel" or "binary".
Returns:
Heatmap as numpy array (H, W), values in [0, 1].
"""
self.model.eval()
output = self.model(image)
logits_key = f"{task}_logits"
if task == "binary":
target = output[logits_key].squeeze()
else:
target = output[logits_key][0, class_idx]
self.model.zero_grad()
target.backward()
# Global average pooling of gradients
assert self.gradients is not None and self.activations is not None, (
"Gradients/activations not captured — ensure backward() was called after generate()"
)
weights = self.gradients.mean(dim=[2, 3], keepdim=True)
cam = (weights * self.activations).sum(dim=1, keepdim=True)
cam = functional.relu(cam)
# Resize to input image dimensions
cam = functional.interpolate(cam, size=image.shape[2:], mode="bilinear", align_corners=False)
cam = cam.squeeze().cpu().numpy()
# Normalize to [0, 1]
cam_min, cam_max = cam.min(), cam.max()
if cam_max - cam_min > 0:
cam = (cam - cam_min) / (cam_max - cam_min)
return cam