# Hugging Face Hub page residue (kept for provenance; not part of the module):
# DhruvB1906's picture
# Upload folder using huggingface_hub
# 4e9a3bc verified
"""Grad-CAM for CNN spectrogram branch."""
import logging
import numpy as np
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM).

    Highlights which regions of the spectrogram contributed to the prediction.

    Forward and backward hooks are attached to one named layer of ``model``;
    they capture that layer's output and the gradient flowing into that output
    so that :meth:`generate_heatmap` can combine them into a heatmap.

    Attributes:
        model: The wrapped PyTorch model.
        target_layer_name: Name (as reported by ``model.named_modules()``) of
            the layer the hooks are attached to.
        activations: Detached output of the target layer from the most recent
            forward pass, or ``None``.
        gradients: Detached gradient w.r.t. the target layer's output from the
            most recent backward pass, or ``None``.
    """

    def __init__(self, model, target_layer_name="cnn.conv_blocks.2"):
        """Initialize Grad-CAM and register the hooks.

        Args:
            model: PyTorch model.
            target_layer_name: Name of target convolutional layer.
        """
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        # Handles returned by the hook registrations; kept so the hooks can be
        # detached again with remove_hooks().
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks on the target layer.

        Logs a warning (and registers nothing) when the layer is not found.
        """

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()

        target_layer = self._find_layer(self.model, self.target_layer_name)
        if target_layer is None:
            logger.warning("Target layer %s not found", self.target_layer_name)
            return

        self._hook_handles = [
            target_layer.register_forward_hook(forward_hook),
            target_layer.register_full_backward_hook(backward_hook),
        ]
        logger.info("Grad-CAM hooks registered on: %s", self.target_layer_name)

    def remove_hooks(self):
        """Detach the hooks registered by this instance from the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _find_layer(self, model, layer_name):
        """Return the sub-module named ``layer_name``, or ``None`` if absent."""
        return next(
            (module for name, module in model.named_modules() if name == layer_name),
            None,
        )

    @staticmethod
    def _min_max_normalize(values: np.ndarray) -> np.ndarray:
        """Rescale ``values`` linearly to [0, 1].

        A flat (constant) array has no contrast to rescale; it is mapped to
        all zeros instead of dividing by zero and producing NaNs.
        """
        if values.size == 0:
            return values
        lowest = values.min()
        shifted = values - lowest
        span = values.max() - lowest
        if span > 0:
            return shifted / span
        return np.zeros_like(
            shifted, dtype=np.result_type(shifted.dtype, np.float32)
        )

    def generate_heatmap(
        self,
        input_tensor: torch.Tensor,
        target_class: int = None,
    ) -> np.ndarray:
        """Generate Grad-CAM heatmap.

        The model is switched to eval mode and its parameter gradients are
        reset as a side effect. Only the first sample of the batch is
        explained.

        Args:
            input_tensor: Input spectrogram (batch, channels, freq, time).
            target_class: Target class index (if None, use predicted class).

        Returns:
            Heatmap array (freq, time) with values in [0, 1], at the spatial
            resolution of the target layer's output.

        Raises:
            RuntimeError: If no hooks are registered (target layer not found
                or hooks removed), or if the target layer produced no
                activations/gradients during the forward/backward pass.
        """
        if not self._hook_handles:
            raise RuntimeError(
                f"Grad-CAM hooks are not registered on "
                f"{self.target_layer_name!r}; cannot generate a heatmap"
            )

        self.model.eval()

        # Drop captures from any previous call so stale data is never reused.
        self.gradients = None
        self.activations = None

        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            # With a fully frozen model nothing would require grad and
            # backward() would fail; tracking the input keeps the graph alive.
            if input_tensor.is_floating_point() and not input_tensor.requires_grad:
                input_tensor = input_tensor.detach().requires_grad_(True)

            output = self.model(input_tensor)

            # Use predicted class of the first sample if not specified.
            if target_class is None:
                target_class = output[0].argmax().item()

            # Back-propagate only the score of the chosen class of sample 0.
            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                f"Target layer {self.target_layer_name!r} did not take part in "
                f"the forward/backward pass; no Grad-CAM data was captured"
            )

        # Channel weights: global average pooling of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        # Weighted combination of the activation maps, keeping only positive
        # evidence (ReLU).
        cam = F.relu((weights * self.activations).sum(dim=1, keepdim=True))

        # Index sample 0 / channel 0 explicitly: squeeze() would also drop
        # spatial axes of size 1 and break the (freq, time) contract.
        cam = self._min_max_normalize(cam[0, 0].cpu().numpy())

        logger.info("Grad-CAM heatmap generated: shape=%s", cam.shape)
        return cam

    def overlay_heatmap(
        self,
        spectrogram: np.ndarray,
        heatmap: np.ndarray,
        alpha: float = 0.5,
    ) -> np.ndarray:
        """Overlay heatmap on spectrogram.

        Args:
            spectrogram: Original spectrogram (freq, time).
            heatmap: Grad-CAM heatmap (freq, time).
            alpha: Transparency (0=spectrogram only, 1=heatmap only).

        Returns:
            Overlayed visualization with the spectrogram's shape.
        """
        # Resize heatmap to match spectrogram if needed.
        if heatmap.shape != spectrogram.shape:
            from scipy.ndimage import zoom

            zoom_factors = (
                spectrogram.shape[0] / heatmap.shape[0],
                spectrogram.shape[1] / heatmap.shape[1],
            )
            heatmap = zoom(heatmap, zoom_factors, order=1)

        # Normalize spectrogram to [0, 1]; a constant spectrogram becomes all
        # zeros rather than NaN.
        spec_norm = self._min_max_normalize(spectrogram)

        # Blend.
        return alpha * heatmap + (1 - alpha) * spec_norm