Spaces:
Sleeping
Sleeping
| """Grad-CAM for CNN spectrogram branch.""" | |
| import logging | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
logger = logging.getLogger(__name__)


class GradCAM:
    """
    Gradient-weighted Class Activation Mapping (Grad-CAM).

    Highlights which regions of the spectrogram contributed to the prediction.

    A forward hook on the target layer stores its output activations. A full
    backward hook stores the gradient with respect to that output. The heatmap
    is the ReLU of the channel-wise sum of activations, each weighted by its
    spatially averaged gradient.

    The hooks stay attached to ``model`` until :meth:`remove_hooks` is called,
    or until the instance is used as a context manager and the ``with`` block
    exits. Prefer reusing one instance per model over constructing a new one
    per call, since every construction attaches another pair of hooks.
    """

    def __init__(self, model, target_layer_name="cnn.conv_blocks.2"):
        """
        Initialize Grad-CAM and attach hooks to the target layer.

        Args:
            model: PyTorch model (``torch.nn.Module``).
            target_layer_name: Name of the target convolutional layer as
                reported by ``model.named_modules()``.
        """
        self.model = model
        self.target_layer_name = target_layer_name
        # Gradient w.r.t. the target layer's output, written by the backward hook.
        self.gradients = None
        # Target layer's output, written by the forward hook.
        self.activations = None
        # Handles returned by register_*_hook, kept so the hooks can be removed.
        self._hook_handles = []
        self._register_hooks()

    def __enter__(self):
        """Return self so the instance can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Remove the hooks on exit; never suppress exceptions."""
        self.remove_hooks()
        return False

    def _register_hooks(self):
        """Register forward and backward hooks on the target layer, if it exists."""

        def forward_hook(module, inputs, output):
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the layer's (first) output.
            self.gradients = grad_output[0].detach()

        target_layer = self._find_layer(self.model, self.target_layer_name)
        if target_layer is None:
            logger.warning("Target layer %s not found", self.target_layer_name)
            return
        self._hook_handles.append(target_layer.register_forward_hook(forward_hook))
        self._hook_handles.append(
            target_layer.register_full_backward_hook(backward_hook)
        )
        logger.info("Grad-CAM hooks registered on: %s", self.target_layer_name)

    def remove_hooks(self):
        """Detach every hook this instance registered on the model. Safe to call twice."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _find_layer(self, model, layer_name):
        """Return the submodule of ``model`` named ``layer_name``, or None if absent."""
        return next(
            (module for name, module in model.named_modules() if name == layer_name),
            None,
        )

    @staticmethod
    def _normalize_unit_range(array):
        """
        Min-max scale ``array`` to [0, 1].

        A constant array (max == min) maps to all zeros instead of the NaNs
        that plain ``(a - min) / (max - min)`` would produce.
        """
        low = array.min()
        span = array.max() - low
        if span > 0:
            return (array - low) / span
        return np.zeros(array.shape, dtype=np.result_type(array.dtype, np.float32))

    def generate_heatmap(
        self,
        input_tensor: torch.Tensor,
        target_class: int = None,
    ) -> np.ndarray:
        """
        Generate a Grad-CAM heatmap for the first sample of ``input_tensor``.

        Side effects: the model is switched to eval mode, and a one-hot
        gradient for ``target_class`` on sample 0 is back-propagated through
        it. Autograd is enabled locally, so this also works when called
        inside ``torch.no_grad()``. The model is assumed to return
        (batch, num_classes) scores.

        Args:
            input_tensor: Input spectrogram (batch, channels, freq, time).
                Only sample 0 is explained.
            target_class: Target class index. If None, the class predicted
                for sample 0 is used.

        Returns:
            Heatmap array with the spatial shape of the target layer's output,
            i.e. that layer's (freq, time) grid, which may be smaller than the
            input. Values are in [0, 1]; a constant map is returned as zeros.

        Raises:
            RuntimeError: If the hooks captured no activations/gradients, for
                example because the target layer was not found, its hooks were
                removed, or it is not on the forward path.
        """
        self.model.eval()
        # Drop tensors from a previous call so stale data is never reused.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs autograd even if the caller disabled it.
        with torch.enable_grad():
            output = self.model(input_tensor)

            # Use the prediction for sample 0 if no class was specified.
            if target_class is None:
                target_class = output.argmax(dim=1)[0].item()

            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            # Single backward pass, so there is no need to retain the graph.
            output.backward(gradient=one_hot)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                f"Grad-CAM hooks on '{self.target_layer_name}' captured no "
                "activations/gradients; check target_layer_name"
            )

        # Channel weights: global average pooling of gradients over spatial dims.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        # Weighted combination of activation channels -> (batch, 1, h, w).
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)
        # Index explicitly instead of squeeze() so singleton freq/time dims survive.
        cam = cam[0, 0].cpu().numpy()
        cam = self._normalize_unit_range(cam)

        logger.info("Grad-CAM heatmap generated: shape=%s", cam.shape)
        return cam

    def overlay_heatmap(
        self,
        spectrogram: np.ndarray,
        heatmap: np.ndarray,
        alpha: float = 0.5,
    ) -> np.ndarray:
        """
        Overlay heatmap on spectrogram.

        When the shapes differ, the heatmap is first resized to the
        spectrogram's shape with linear interpolation
        (``scipy.ndimage.zoom``, order=1). The spectrogram is then min-max
        scaled to [0, 1], and the two are alpha-blended.

        Args:
            spectrogram: Original 2-D spectrogram (freq, time).
            heatmap: 2-D Grad-CAM heatmap, e.g. the output of generate_heatmap.
            alpha: Transparency (0=spectrogram only, 1=heatmap only).

        Returns:
            Blended array with the spectrogram's shape. A constant spectrogram
            contributes zeros rather than NaNs.
        """
        spectrogram = np.asarray(spectrogram)
        heatmap = np.asarray(heatmap)

        if heatmap.shape != spectrogram.shape:
            from scipy.ndimage import zoom

            zoom_factors = (
                spectrogram.shape[0] / heatmap.shape[0],
                spectrogram.shape[1] / heatmap.shape[1],
            )
            heatmap = zoom(heatmap, zoom_factors, order=1)

        spec_norm = self._normalize_unit_range(spectrogram)
        return alpha * heatmap + (1 - alpha) * spec_norm