adityasync's picture
Clean OncoVision-X deployment with LFS
8960670
#!/usr/bin/env python3
"""
3D GradCAM for OncoVision-X Lung Nodule Classification.
Generates gradient-weighted class activation maps to visualize
which regions of the input the model focuses on for its predictions.
Usage:
from src.explainability.gradcam import GradCAM3D
gradcam = GradCAM3D(model, target_stream='nodule_stream')
heatmap, probability = gradcam(nodule_patch, context_patch)
"""
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pathlib import Path
class GradCAM3D:
    """3D Gradient-weighted Class Activation Mapping.

    Generates heatmaps showing which spatial regions of the input
    volume contribute most to the model's prediction.

    Supports two target streams:
        - 'nodule_stream': Visualize nodule patch focus (64³)
        - 'context_stream': Visualize context patch focus (48³)

    The instance registers a forward and a backward hook on the chosen
    layer. Call :meth:`remove_hooks` (or use the instance as a context
    manager) once done so the hooks do not stay attached to the model.
    """

    _VALID_STREAMS = ('nodule_stream', 'context_stream')

    def __init__(self, model, target_stream='nodule_stream'):
        """
        Args:
            model: DCANet model (unwrapped from DataParallel). It is switched
                to eval mode.
            target_stream: 'nodule_stream' or 'context_stream'

        Raises:
            ValueError: if ``target_stream`` is not one of the two streams.
            RuntimeError: if no target layer can be found in the nodule
                backbone.
        """
        # Validate before touching the model so a bad argument has no
        # side effects.
        if target_stream not in self._VALID_STREAMS:
            raise ValueError(f"Unknown target_stream: {target_stream}")

        self.model = model
        self.model.eval()
        self.target_stream = target_stream

        # Filled in by the hooks during generate().
        self.activations = None
        self.gradients = None

        if target_stream == 'nodule_stream':
            target_layer = self._get_nodule_target()
        else:
            target_layer = self._get_context_target()

        # Keep the handles: without them the hooks can never be detached and
        # every new GradCAM3D built on the same model would stack more hooks.
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False

    def remove_hooks(self):
        """Detach the hooks registered on the target layer (idempotent)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _get_nodule_target(self):
        """Get the last convolutional layer of the nodule stream backbone.

        Preference order: ``backbone.conv_head``, then the last entry of
        ``backbone.blocks``, then the last ``Conv2d`` found in the backbone.

        Raises:
            RuntimeError: if none of the above exists.
        """
        backbone = self.model.nodule_stream.backbone
        if hasattr(backbone, 'conv_head'):
            return backbone.conv_head
        if hasattr(backbone, 'blocks'):
            return backbone.blocks[-1]

        # Fallback: last Conv2d module in traversal order.
        last_conv = None
        for module in backbone.modules():
            if isinstance(module, torch.nn.Conv2d):
                last_conv = module
        if last_conv is None:
            raise RuntimeError("Could not find target layer in nodule backbone")
        return last_conv

    def _get_context_target(self):
        """Get the last conv block of the context stream."""
        return self.model.context_stream.block3

    def _save_activation(self, module, input, output):
        """Forward hook to save activations."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook to save gradients."""
        self.gradients = grad_output[0].detach()

    @staticmethod
    def _prepare_input(patch, device):
        """Return a grad-enabled tensor on ``device`` built from ``patch``.

        numpy arrays are converted to float32 tensors; 3-D inputs (numpy or
        tensor) get leading batch and channel axes. The result is detached
        from ``patch`` so that enabling gradients never flips
        ``requires_grad`` on the caller's own tensor (``Tensor.to`` returns
        the very same object when the device already matches).
        """
        if isinstance(patch, np.ndarray):
            patch = torch.from_numpy(patch.astype(np.float32))
        if patch.ndim == 3:
            patch = patch[None, None]
        return patch.detach().to(device).requires_grad_(True)

    def _nodule_heatmap(self, nodule_patch):
        """Build the heatmap from per-slice 2D activations.

        Assumes the hooked layer produced (N, C, h, w) feature maps with the
        slices of the first batch item stored first, as the original
        implementation did.
        """
        # Channel weights: spatial mean of the gradients, (N, C, 1, 1).
        weights = self.gradients.mean(dim=(-2, -1), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1))  # (N, h, w)

        # Keep only the slices belonging to the first batch item.
        slices_per_item = cam.shape[0] // nodule_patch.shape[0]
        cam = cam[:slices_per_item]

        # Upsample every slice to the in-plane size of the input patch.
        cam = F.interpolate(
            cam.unsqueeze(1),
            size=tuple(nodule_patch.shape[3:5]),
            mode='bilinear', align_corners=False,
        ).squeeze(1)
        return cam.cpu().numpy()

    def _context_heatmap(self, context_patch):
        """Build the heatmap from 3D activations of shape (B, C, d, h, w)."""
        weights = self.gradients.mean(dim=(-3, -2, -1), keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = F.interpolate(
            cam, size=tuple(context_patch.shape[2:]),
            mode='trilinear', align_corners=False,
        )
        return cam.squeeze().cpu().numpy()

    def generate(self, nodule_patch, context_patch, device='cpu'):
        """Generate GradCAM heatmap.

        Args:
            nodule_patch: numpy array or tensor, (64, 64, 64) or
                (1, 1, 64, 64, 64)
            context_patch: numpy array or tensor, (48, 48, 48) or
                (1, 1, 48, 48, 48)
            device: torch device

        Returns:
            heatmap: numpy array normalized to [0, 1]. For the context stream
                it has the spatial size of ``context_patch``. For the nodule
                stream each slice has the in-plane size of ``nodule_patch``;
                the number of slices is whatever the hooked layer produced
                per batch item.
            probability: sigmoid of the model output, as a float

        Raises:
            RuntimeError: if the hooks did not fire (e.g. they were removed).
        """
        nodule_patch = self._prepare_input(nodule_patch, device)
        context_patch = self._prepare_input(context_patch, device)

        # Drop anything captured by a previous call so stale tensors can
        # never be mistaken for this call's results.
        self.activations = None
        self.gradients = None

        # Gradients are required even when the caller is inside no_grad().
        with torch.enable_grad():
            self.model.zero_grad()
            logits = self.model(nodule_patch, context_patch)
            score = logits.squeeze()
            prob = torch.sigmoid(score).item()
            # Backward pass: gradient w.r.t. the class score.
            score.backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError("Hooks did not capture gradients/activations")

        if self.target_stream == 'nodule_stream':
            heatmap = self._nodule_heatmap(nodule_patch)
        else:
            heatmap = self._context_heatmap(context_patch)

        # Normalize to [0, 1]; an all-zero map is returned unchanged.
        if heatmap.max() > 0:
            heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        return heatmap, prob

    def __call__(self, nodule_patch, context_patch, device='cpu'):
        """Shorthand for :meth:`generate`."""
        return self.generate(nodule_patch, context_patch, device)
def plot_gradcam_slices(scan_patch, heatmap, probability, output_path,
                        num_slices=8, title="GradCAM Visualization"):
    """Plot GradCAM overlay on selected slices.

    The figure has two rows: the grayscale slices on top and the same
    slices with the heatmap overlaid below, plus a colorbar.

    Args:
        scan_patch: 3D numpy array (D, H, W) - original input
        heatmap: 3D numpy array (D, H, W) - GradCAM heatmap
        probability: float - model prediction
        output_path: str - save path (parent directories are created)
        num_slices: int - number of slices to display (>= 1)
        title: str - plot title

    Returns:
        ``output_path`` as passed in.

    Raises:
        ValueError: if ``num_slices`` is below 1 or there is no slice to show.
    """
    if num_slices < 1:
        raise ValueError("num_slices must be at least 1")

    # Only index slices present in both volumes.
    depth = min(scan_patch.shape[0], heatmap.shape[0])
    if depth < 1:
        raise ValueError("scan_patch and heatmap must contain at least one slice")

    # Skip the outer eighth at both ends. For shallow volumes the margin is
    # 0 and the upper bound would equal `depth`, so clip to valid indices.
    margin = depth // 8
    slice_indices = np.clip(
        np.linspace(margin, depth - margin, num_slices, dtype=int),
        0, depth - 1,
    )

    # squeeze=False keeps `axes` 2-D even when num_slices == 1.
    fig, axes = plt.subplots(
        2, num_slices, figsize=(3 * num_slices, 7), squeeze=False
    )
    try:
        is_malignant = probability > 0.5
        label = "MALIGNANT" if is_malignant else "BENIGN"
        color = 'red' if is_malignant else 'green'
        fig.suptitle(
            f"{title}\nPrediction: {label} ({probability:.1%})",
            fontsize=14, fontweight='bold', color=color
        )

        overlay = None
        for col, idx in enumerate(slice_indices):
            top, bottom = axes[0, col], axes[1, col]

            # Original slice
            top.imshow(scan_patch[idx], cmap='gray', vmin=-1, vmax=1)
            top.set_title(f"Slice {idx}", fontsize=9)
            top.axis('off')

            # Overlay
            bottom.imshow(scan_patch[idx], cmap='gray', vmin=-1, vmax=1)
            overlay = bottom.imshow(
                heatmap[idx], cmap='jet', alpha=0.5, vmin=0, vmax=1
            )
            bottom.set_title("GradCAM", fontsize=9)
            bottom.axis('off')

        # Lay out the subplot grid first: the manually placed colorbar axes
        # is not managed by tight_layout and would trigger a warning.
        fig.tight_layout(rect=[0, 0, 0.9, 0.92])
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.35])
        fig.colorbar(overlay, cax=cbar_ax, label='Attention')

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return output_path
def generate_gradcam_report(model, dataloader, device, output_dir,
                            num_samples=10, stream='nodule_stream'):
    """Generate GradCAM visualizations for multiple samples.

    Iterates over the dataloader sample by sample, computes a heatmap for
    each one and saves an overlay plot. Samples for which GradCAM raises are
    reported on stdout and skipped; they do not count towards
    ``num_samples``.

    Args:
        model: Trained DCANet model (a DataParallel wrapper is unwrapped)
        dataloader: Test DataLoader yielding (nodule, context, labels)
        device: torch device
        output_dir: Directory to save GradCAM plots
        num_samples: Number of samples to visualize
        stream: Which stream to visualize

    Returns:
        List of output paths
    """
    # Unwrap DataParallel if needed
    raw_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    gradcam = GradCAM3D(raw_model, target_stream=stream)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    outputs = []
    sample_count = 0
    for nodule, context, labels in dataloader:
        for i in range(nodule.shape[0]):
            if sample_count >= num_samples:
                return outputs

            # Slicing keeps the batch axis the model expects.
            nod = nodule[i:i + 1]
            ctx = context[i:i + 1]
            label = labels[i].item()

            try:
                heatmap, prob = gradcam.generate(nod, ctx, device)
            except Exception as e:
                print(f"  GradCAM failed for sample {sample_count}: {e}")
                continue

            # The heatmap lives on the grid of the visualized stream, so the
            # background must be that stream's patch, not always the nodule.
            shown = ctx if stream == 'context_stream' else nod
            # detach()/cpu() make the conversion safe whether or not the
            # tensor requires grad or sits on an accelerator.
            scan_slice = shown.detach().cpu().squeeze().numpy()

            # Crop along depth if heatmap and scan disagree
            if heatmap.shape != scan_slice.shape:
                min_d = min(heatmap.shape[0], scan_slice.shape[0])
                heatmap = heatmap[:min_d]
                scan_slice = scan_slice[:min_d]

            is_positive = label == 1
            predicted_positive = prob > 0.5
            correct = is_positive == predicted_positive if label in (0, 1) else False

            gt_str = "pos" if is_positive else "neg"
            pred_str = "malignant" if predicted_positive else "benign"
            out_path = output_dir / f"gradcam_{sample_count:03d}_{gt_str}_pred_{pred_str}.png"

            plot_gradcam_slices(
                scan_slice, heatmap, prob, str(out_path),
                title=f"Sample {sample_count} — GT: {'Cancer' if is_positive else 'Benign'} | "
                      f"{'Correct' if correct else 'WRONG'}"
            )
            outputs.append(str(out_path))
            sample_count += 1
    return outputs