| """ |
| Explainability utilities for DeepFake detection models. |
| |
| Provides: |
| - GradCAM: For CNN-based models (EfficientNet, CompactGradientNet) |
| - Attention Rollout: For ViT/DeiT transformer models |
| - Heatmap visualization utilities |
| """ |
|
|
| import base64 |
| import io |
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
|
|
| from app.core.logging import get_logger |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping for CNN models.

    Computes importance heatmaps by weighting the feature-map activations of
    a chosen layer by the gradients flowing into them from the target score.

    The instance registers a forward hook and a full backward hook on
    ``target_layer``; call :meth:`remove_hooks` when finished to detach them.

    Usage:
        gradcam = GradCAM(model, target_layer)
        heatmap = gradcam(input_tensor, target_class=1)
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        """
        Args:
            model: The CNN model
            target_layer: The convolutional layer to compute Grad-CAM on
                         (typically the last conv layer before pooling)
        """
        # Created first so remove_hooks()/__del__ stay safe even if hook
        # registration below raises.
        self._hooks: List = []
        self.model = model
        self.target_layer = target_layer
        self.gradients: Optional[torch.Tensor] = None
        self.activations: Optional[torch.Tensor] = None

        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks on target layer."""
        def forward_hook(module, inputs, output):
            # Detached copy: only the values are needed for the CAM.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # grad_output[0] is the gradient w.r.t. the layer's output.
            self.gradients = grad_output[0].detach()

        self._hooks.append(
            self.target_layer.register_forward_hook(forward_hook)
        )
        self._hooks.append(
            self.target_layer.register_full_backward_hook(backward_hook)
        )

    def remove_hooks(self):
        """Remove registered hooks (safe to call more than once)."""
        hooks = getattr(self, "_hooks", None)
        if not hooks:
            return
        for hook in hooks:
            hook.remove()
        hooks.clear()

    def __call__(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        output_size: Tuple[int, int] = (224, 224)
    ) -> np.ndarray:
        """
        Compute Grad-CAM heatmap.

        The model is switched to eval mode and its parameter gradients are
        zeroed as a side effect.

        Args:
            input_tensor: Input image tensor [1, C, H, W]
            target_class: Class index to compute gradients for.
                         If None, uses the predicted class. For a model with
                         a single output logit the predicted class is 1 when
                         the logit is positive and 0 otherwise.
            output_size: Size to resize the heatmap to (H, W)

        Returns:
            Normalized heatmap as numpy array [H, W] in range [0, 1].
            An all-zero array is returned if the hooks captured nothing.
        """
        self.model.eval()

        # Forget whatever a previous call captured so stale tensors can never
        # be mistaken for this call's result.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            input_tensor = input_tensor.clone().requires_grad_(True)

            output = self.model(input_tensor)

            # Some models return (logits, aux...); only the first item is used.
            logits = output[0] if isinstance(output, tuple) else output

            if logits.dim() == 1:
                logits = logits.unsqueeze(0)

            multi_class = logits.shape[-1] > 1

            if target_class is None:
                if multi_class:
                    target_class = logits.argmax(dim=1).item()
                else:
                    # argmax over a width-1 dimension is always 0, so derive
                    # the prediction from the sign of the single logit. This
                    # matches the sign convention used for target_score below.
                    target_class = int(logits[0, 0].item() > 0)

            self.model.zero_grad()

            if multi_class:
                target_score = logits[0, target_class]
            else:
                # Single logit: class 1 follows the logit, class 0 its negation.
                target_score = logits[0, 0] if target_class == 1 else -logits[0, 0]

            # Only one backward pass is made, so the graph need not be retained.
            target_score.backward()

        if self.gradients is None or self.activations is None:
            logger.warning("Gradients or activations not captured")
            return np.zeros(output_size, dtype=np.float32)

        # Channel importance = spatial mean of the gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        # Weighted sum of activation maps over channels.
        cam = (weights * self.activations).sum(dim=1, keepdim=True)

        # Keep only features with a positive influence on the target score.
        cam = F.relu(cam)

        # Min-max normalize; left unscaled when the map is flat.
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            cam = cam / cam_max

        cam = F.interpolate(
            cam,
            size=output_size,
            mode='bilinear',
            align_corners=False
        )

        heatmap = cam.squeeze().cpu().numpy()

        return heatmap

    def __del__(self):
        self.remove_hooks()
|
|
|
|
def attention_rollout(
    attentions: Union[List[torch.Tensor], torch.Tensor],
    discard_ratio: float = 0.0,
    head_fusion: str = "mean",
    num_prefix_tokens: int = 1
) -> np.ndarray:
    """
    Compute attention rollout for Vision Transformers.

    Aggregates attention across all layers by matrix multiplication,
    accounting for residual connections (each layer contributes
    ``0.5 * attention + 0.5 * identity``).

    Args:
        attentions: Attention tensors from each layer. Can be:
            - List of tensors, each shape [batch, num_heads, seq_len, seq_len] or [seq_len, seq_len]
            - Stacked tensor of shape [num_layers, seq_len, seq_len] (already head-fused)
        discard_ratio: Fraction of lowest attention weights to discard
        head_fusion: How to combine attention heads ("mean", "max", "min");
            any other value falls back to "mean"
        num_prefix_tokens: Number of special tokens (1 for ViT cls, 2 for DeiT cls+dist)

    Returns:
        Attention map as numpy array of shape (grid_size, grid_size),
        min-max normalized to [0, 1]. A 14x14 all-zero map is returned when
        there is nothing usable to roll out (no input, only tensors of
        unsupported rank, or no patch tokens left after the prefix tokens).
    """
    default_grid_size = 14

    def _empty_map() -> np.ndarray:
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    if attentions is None:
        logger.warning("No attention tensors provided (None)")
        return _empty_map()

    if isinstance(attentions, torch.Tensor):
        if attentions.numel() == 0:
            logger.warning("Empty attention tensor provided")
            return _empty_map()
        # Split the stacked tensor along its first (layer) dimension.
        attentions = [attentions[i] for i in range(attentions.shape[0])]

    if len(attentions) == 0:
        logger.warning("Empty attention list provided")
        return _empty_map()

    result = None

    for attention in attentions:
        if attention.dim() == 2:
            # [seq, seq] -> add a batch dimension for bmm.
            attention_fused = attention.unsqueeze(0)
        elif attention.dim() == 3:
            # Treated as already head-fused [batch, seq, seq].
            attention_fused = attention
        elif attention.dim() == 4:
            # [batch, heads, seq, seq] -> fuse the head dimension.
            if head_fusion == "max":
                attention_fused = attention.max(dim=1)[0]
            elif head_fusion == "min":
                attention_fused = attention.min(dim=1)[0]
            else:
                # "mean" and any unrecognized value.
                attention_fused = attention.mean(dim=1)
        else:
            logger.warning(f"Unexpected attention shape: {attention.shape}")
            continue

        if discard_ratio > 0:
            # reshape (not view) so non-contiguous inputs are accepted too.
            flat = attention_fused.reshape(attention_fused.size(0), -1)
            threshold = torch.quantile(flat, discard_ratio, dim=1, keepdim=True)
            threshold = threshold.reshape(attention_fused.size(0), 1, 1)
            attention_fused = torch.where(
                attention_fused < threshold,
                torch.zeros_like(attention_fused),
                attention_fused
            )
            # Re-normalize rows after zeroing entries; epsilon avoids 0/0.
            attention_fused = attention_fused / (attention_fused.sum(dim=-1, keepdim=True) + 1e-9)

        seq_len = attention_fused.size(-1)
        identity = torch.eye(seq_len, device=attention_fused.device, dtype=attention_fused.dtype)
        attention_with_residual = 0.5 * attention_fused + 0.5 * identity.unsqueeze(0)

        if result is None:
            result = attention_with_residual
        else:
            result = torch.bmm(attention_with_residual, result)

    if result is None:
        # Every entry was skipped above, so there is nothing to index into.
        logger.warning("No usable attention tensors provided")
        return _empty_map()

    # Row 0 of the first batch item is the first prefix token's attention;
    # the prefix-token columns are dropped so only patch tokens remain.
    cls_attention = result[0, 0, num_prefix_tokens:]

    num_patches = cls_attention.size(0)
    if num_patches == 0:
        logger.warning("No patch tokens left after removing prefix tokens")
        return _empty_map()

    grid_size = math.isqrt(num_patches)

    if grid_size * grid_size != num_patches:
        logger.warning(f"Non-square number of patches: {num_patches}")
        # Zero-pad up to the next perfect square so the map can be reshaped.
        grid_size = int(math.ceil(math.sqrt(num_patches)))
        padded = torch.zeros(
            grid_size * grid_size,
            device=cls_attention.device,
            dtype=cls_attention.dtype
        )
        padded[:num_patches] = cls_attention
        cls_attention = padded

    # detach(): .numpy() refuses tensors that still require grad.
    attention_map = cls_attention.reshape(grid_size, grid_size).detach().cpu().numpy()

    # Min-max normalize; a flat map stays all zeros.
    attention_map = attention_map - attention_map.min()
    if attention_map.max() > 0:
        attention_map = attention_map / attention_map.max()

    return attention_map
|
|
|
|
def heatmap_to_base64(
    heatmap: np.ndarray,
    colormap: str = "turbo",
    output_size: Optional[Tuple[int, int]] = None
) -> str:
    """
    Convert a heatmap array to base64-encoded PNG string.

    Args:
        heatmap: 2D numpy array with values in [0, 1]
        colormap: Matplotlib colormap name ("turbo", "jet", "viridis", "inferno")
        output_size: Optional (width, height) to resize to

    Returns:
        Base64-encoded PNG string (without data:image/png;base64, prefix)

    Raises:
        ValueError: If ``colormap`` is not a registered Matplotlib colormap.
    """
    import matplotlib
    matplotlib.use('Agg')

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; the colormap registry is the supported lookup. Fall back to the
    # legacy call only on releases that predate the registry.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        try:
            cmap = registry[colormap]
        except KeyError as exc:
            # Keep the exception type the legacy lookup raised.
            raise ValueError(f"Unknown colormap: {colormap!r}") from exc
    else:
        import matplotlib.cm as cm
        cmap = cm.get_cmap(colormap)

    # Calling the colormap on an [H, W] float array yields RGBA in [0, 1].
    colored = cmap(heatmap)

    # Drop alpha and convert to 8-bit RGB.
    rgb = (colored[:, :, :3] * 255).astype(np.uint8)

    img = Image.fromarray(rgb)

    if output_size is not None:
        img = img.resize(output_size, Image.BILINEAR)

    buffer = io.BytesIO()
    img.save(buffer, format='PNG', optimize=True)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
|
|
def overlay_heatmap_on_image(
    image: Union[np.ndarray, Image.Image],
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "turbo"
) -> str:
    """
    Overlay a heatmap on an image and return as base64 PNG.

    Args:
        image: Original image (numpy array HWC or PIL Image). Arrays that are
            not uint8 are multiplied by 255, i.e. they are assumed to hold
            values in [0, 1]; out-of-range results are clipped.
        heatmap: 2D heatmap array [0, 1]; values outside the range are clipped
        alpha: Blend factor (0 = image only, 1 = heatmap only)
        colormap: Matplotlib colormap name

    Returns:
        Base64-encoded PNG of the overlaid image

    Raises:
        ValueError: If ``colormap`` is not a registered Matplotlib colormap.
    """
    import matplotlib
    matplotlib.use('Agg')

    if isinstance(image, Image.Image):
        # Palette / alpha-paired / non-RGB colour spaces would otherwise be
        # blended as raw indices or with the wrong channel count.
        if image.mode in ("P", "PA", "LA", "CMYK", "YCbCr", "HSV"):
            image = image.convert("RGB")
        image = np.array(image)

    if image.dtype != np.uint8:
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around modulo 256.
        image = np.clip(image * 255, 0, 255).astype(np.uint8)
    if image.ndim == 2:
        # Grayscale -> replicate into three channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 1:
        image = np.concatenate([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # Drop the alpha channel.
        image = image[:, :, :3]

    h, w = image.shape[:2]
    # NaNs become 0 and values are clamped so the uint8 cast cannot wrap.
    heatmap_unit = np.clip(np.nan_to_num(heatmap, nan=0.0), 0.0, 1.0)
    # PIL's resize takes (width, height).
    heatmap_resized = np.array(
        Image.fromarray((heatmap_unit * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR)
    ) / 255.0

    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; the colormap registry is the supported lookup. Fall back to the
    # legacy call only on releases that predate the registry.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        try:
            cmap = registry[colormap]
        except KeyError as exc:
            # Keep the exception type the legacy lookup raised.
            raise ValueError(f"Unknown colormap: {colormap!r}") from exc
    else:
        import matplotlib.cm as cm
        cmap = cm.get_cmap(colormap)

    heatmap_colored = cmap(heatmap_resized)[:, :, :3]
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)

    blended = (
        (1 - alpha) * image.astype(np.float32) +
        alpha * heatmap_colored.astype(np.float32)
    )
    # A no-op for alpha in [0, 1]; prevents wrap-around for other values.
    blended = np.clip(blended, 0, 255).astype(np.uint8)

    img = Image.fromarray(blended)
    buffer = io.BytesIO()
    img.save(buffer, format='PNG', optimize=True)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
|
|
class AttentionExtractor:
    """
    Hook-based attention extractor for ViT/DeiT models.

    Registers a forward hook on the ``attn`` module of every transformer
    block. Each time such a module runs, the hook recomputes the softmax
    attention weights from the module's input (via its ``qkv`` projection and
    ``num_heads``) and stores a detached copy.

    Usage:
        extractor = AttentionExtractor(model.blocks)
        output = model(input)
        attentions = extractor.get_attentions()
        extractor.clear()
    """

    def __init__(self, blocks: nn.ModuleList):
        """
        Args:
            blocks: List of transformer blocks (each should have .attn attribute);
                blocks without one are skipped
        """
        # Created first so remove_hooks()/__del__ stay safe even if hook
        # registration below raises.
        self._hooks: List = []
        self.attentions: List[torch.Tensor] = []

        for block in blocks:
            if hasattr(block, 'attn'):
                hook = block.attn.register_forward_hook(self._make_hook())
                self._hooks.append(hook)

    @staticmethod
    def _attention_weights(attn: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Return softmax(q @ k^T * scale) for ``attn`` applied to ``x`` [B, N, C]."""
        B, N, C = x.shape
        head_dim = C // attn.num_heads

        # qkv projection -> [3, B, num_heads, N, head_dim]
        qkv = attn.qkv(x).reshape(B, N, 3, attn.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]

        scale = head_dim ** -0.5
        return ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1)

    def _make_hook(self):
        """Create a forward hook that captures attention weights."""
        # Capture the list and the helper rather than ``self`` so the hook
        # stored on the module does not keep this extractor alive.
        storage = self.attentions
        compute = self._attention_weights

        def hook(module, inputs, output):
            # Skip modules that do not expose the qkv/num_heads attributes the
            # recomputation relies on, instead of breaking the forward pass.
            if not inputs or not hasattr(module, 'qkv') or not hasattr(module, 'num_heads'):
                return
            with torch.no_grad():
                storage.append(compute(module, inputs[0]).detach())
        return hook

    def extract_attention_from_block(
        self,
        block: nn.Module,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract attention weights from a single transformer block.

        Args:
            block: Transformer block with attention module
            x: Input tensor [B, seq_len, embed_dim]

        Returns:
            Attention weights [B, num_heads, seq_len, seq_len]
        """
        return self._attention_weights(block.attn, x)

    def get_attentions(self) -> List[torch.Tensor]:
        """Return captured attention tensors."""
        return self.attentions

    def clear(self):
        """Clear captured attentions."""
        # Cleared in place: the hooks append to this same list object.
        self.attentions.clear()

    def remove_hooks(self):
        """Remove all hooks (safe to call more than once)."""
        hooks = getattr(self, "_hooks", None)
        if not hooks:
            return
        for hook in hooks:
            hook.remove()
        hooks.clear()

    def __del__(self):
        self.remove_hooks()
|
|
|
|
def compute_vit_attention_rollout(
    model: nn.Module,
    input_tensor: torch.Tensor,
    blocks_attr: str = "blocks",
    num_prefix_tokens: int = 1,
    output_size: Tuple[int, int] = (224, 224)
) -> np.ndarray:
    """
    Compute attention rollout for a ViT-style model.

    Temporarily hooks the ``attn`` module of every block, runs one forward
    pass without gradients, rolls the captured attention out across layers
    and resizes the result. The model is switched to eval mode as a side
    effect; the hooks are always removed before returning.

    Args:
        model: The ViT model (should have .blocks attribute with transformer layers)
        input_tensor: Input image tensor [1, 3, H, W]
        blocks_attr: Attribute name for transformer blocks (e.g., "blocks" or "vit.blocks")
        num_prefix_tokens: Number of prefix tokens (1 for CLS, 2 for CLS+DIST)
        output_size: Size to resize output heatmap, as (H, W)

    Returns:
        Attention heatmap as numpy array [H, W] in range [0, 1]
    """
    model.eval()

    # Resolve a dotted attribute path such as "vit.blocks".
    blocks = model
    for attr in blocks_attr.split('.'):
        blocks = getattr(blocks, attr)

    attentions = []

    def make_attn_hook(storage):
        def hook(module, inputs, output):
            # Recompute the attention weights from the module's input, using
            # its qkv projection and head count.
            x = inputs[0]
            B, N, C = x.shape

            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads)
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k = qkv[0], qkv[1]

            scale = (C // module.num_heads) ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)

            storage.append(attn.detach())
        return hook

    hooks = []
    try:
        # Registration sits inside the try so a failure part-way through
        # still removes the hooks that were already attached.
        for block in blocks:
            if hasattr(block, 'attn'):
                hooks.append(
                    block.attn.register_forward_hook(make_attn_hook(attentions))
                )

        with torch.no_grad():
            model(input_tensor)

        if not attentions:
            logger.warning("No attention weights captured")
            return np.zeros(output_size, dtype=np.float32)

        rollout = attention_rollout(
            attentions,
            num_prefix_tokens=num_prefix_tokens
        )

        # output_size is (H, W) -- that is the shape of the zero fallback
        # above -- whereas PIL's resize expects (width, height).
        out_h, out_w = output_size
        rollout_img = Image.fromarray((rollout * 255).astype(np.uint8))
        rollout_img = rollout_img.resize((out_w, out_h), Image.BILINEAR)

        return np.array(rollout_img) / 255.0

    finally:
        for h in hooks:
            h.remove()
|
|
|
|
def compute_focus_summary(
    heatmap: np.ndarray,
    threshold: float = 0.5
) -> str:
    """
    Describe in words where a heatmap's strongest activations lie.

    The heatmap is rescaled by its maximum (when positive), cells above
    ``threshold`` are selected, and the selection's centroid, spread and
    coverage are turned into a short phrase such as
    "highly concentrated upper-center region".

    Args:
        heatmap: 2D numpy array with values in [0, 1], shape (H, W)
        threshold: Threshold for considering a region as "high activation"

    Returns:
        Human-readable focus summary string
    """
    if heatmap is None or heatmap.size == 0:
        return "no activation data available"

    values = np.array(heatmap, dtype=np.float32)
    peak = values.max()
    if peak > 0:
        values = values / peak

    rows, cols = values.shape

    # Select "hot" cells; if none clear the threshold, retry at half the peak.
    hot = values > threshold
    if not hot.any():
        hot = values > (values.max() * 0.5)
        if not hot.any():
            return "very low activation across entire image"

    ys, xs = np.nonzero(hot)
    hot_count = len(ys)

    # Centroid and spread, expressed as fractions of the image size.
    center_y = ys.mean() / rows
    center_x = xs.mean() / cols
    std_y = ys.std() / rows if hot_count > 1 else 0
    std_x = xs.std() / cols if hot_count > 1 else 0
    spread = (std_y + std_x) / 2

    # Fraction of all cells that are hot.
    coverage = hot.sum() / hot.size

    if spread < 0.15:
        concentration = "highly concentrated"
    elif spread < 0.25:
        concentration = "moderately concentrated"
    else:
        concentration = "spread across"

    # Thirds of the image along each axis.
    if center_y < 0.33:
        vertical = "upper"
    elif center_y > 0.67:
        vertical = "lower"
    else:
        vertical = "middle"

    if center_x < 0.33:
        horizontal = "left"
    elif center_x > 0.67:
        horizontal = "right"
    else:
        horizontal = "center"

    if (vertical, horizontal) == ("middle", "center"):
        location = "central region"
    else:
        location = f"{vertical}-{horizontal} region"

    pieces = [concentration, location]
    if coverage > 0.4:
        pieces.append(f"(~{int(coverage*100)}% of image)")

    summary = " ".join(pieces)

    # Heuristic hints appended to the phrase.
    if center_y < 0.5 and 0.3 < center_x < 0.7 and spread < 0.2:
        return summary + " (likely face/subject area)"
    if spread > 0.3:
        return summary + " (examining multiple regions)"
    return summary
|
|
|
|