""" Explainability utilities for DeepFake detection models. Provides: - GradCAM: For CNN-based models (EfficientNet, CompactGradientNet) - Attention Rollout: For ViT/DeiT transformer models - Heatmap visualization utilities """ import base64 import io import math from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from app.core.logging import get_logger logger = get_logger(__name__) class GradCAM: """ Gradient-weighted Class Activation Mapping for CNN models. Computes importance heatmaps by weighting feature map activations by the gradients flowing into them from the target class. Usage: gradcam = GradCAM(model, target_layer) heatmap = gradcam(input_tensor, target_class=1) """ def __init__(self, model: nn.Module, target_layer: nn.Module): """ Args: model: The CNN model target_layer: The convolutional layer to compute Grad-CAM on (typically the last conv layer before pooling) """ self.model = model self.target_layer = target_layer self.gradients: Optional[torch.Tensor] = None self.activations: Optional[torch.Tensor] = None self._hooks: List = [] self._register_hooks() def _register_hooks(self): """Register forward and backward hooks on target layer.""" def forward_hook(module, input, output): self.activations = output.detach() def backward_hook(module, grad_input, grad_output): self.gradients = grad_output[0].detach() self._hooks.append( self.target_layer.register_forward_hook(forward_hook) ) self._hooks.append( self.target_layer.register_full_backward_hook(backward_hook) ) def remove_hooks(self): """Remove registered hooks.""" for hook in self._hooks: hook.remove() self._hooks.clear() def __call__( self, input_tensor: torch.Tensor, target_class: Optional[int] = None, output_size: Tuple[int, int] = (224, 224) ) -> np.ndarray: """ Compute Grad-CAM heatmap. Args: input_tensor: Input image tensor [1, C, H, W] target_class: Class index to compute gradients for. If None, uses the predicted class. output_size: Size to resize the heatmap to (H, W) Returns: Normalized heatmap as numpy array [H, W] in range [0, 1] """ self.model.eval() # Enable gradients for this forward pass input_tensor = input_tensor.clone().requires_grad_(True) # Forward pass output = self.model(input_tensor) # Handle different output formats if isinstance(output, tuple): logits = output[0] # Some models return (logits, embeddings) else: logits = output # Ensure logits is 2D [batch, classes] if logits.dim() == 1: logits = logits.unsqueeze(0) # Determine target class if target_class is None: target_class = logits.argmax(dim=1).item() # Zero gradients self.model.zero_grad() # Backward pass for target class if logits.shape[-1] > 1: # Multi-class: select target class score target_score = logits[0, target_class] else: # Binary with single output: use the logit directly target_score = logits[0, 0] if target_class == 1 else -logits[0, 0] target_score.backward(retain_graph=True) # Compute Grad-CAM if self.gradients is None or self.activations is None: logger.warning("Gradients or activations not captured") return np.zeros(output_size, dtype=np.float32) # Global average pool gradients to get weights weights = self.gradients.mean(dim=(2, 3), keepdim=True) # [1, C, 1, 1] # Weighted combination of activation maps cam = (weights * self.activations).sum(dim=1, keepdim=True) # [1, 1, H, W] # ReLU to keep only positive contributions cam = F.relu(cam) # Normalize cam = cam - cam.min() cam_max = cam.max() if cam_max > 0: cam = cam / cam_max # Resize to output size cam = F.interpolate( cam, size=output_size, mode='bilinear', align_corners=False ) # Convert to numpy heatmap = cam.squeeze().cpu().numpy() return heatmap def __del__(self): self.remove_hooks() def attention_rollout( attentions: Union[List[torch.Tensor], torch.Tensor], discard_ratio: float = 0.0, head_fusion: str = "mean", num_prefix_tokens: int = 1 ) -> np.ndarray: """ Compute attention rollout for Vision Transformers. Aggregates attention across all layers by matrix multiplication, accounting for residual connections. Args: attentions: Attention tensors from each layer. Can be: - List of tensors, each shape [batch, num_heads, seq_len, seq_len] or [seq_len, seq_len] - Stacked tensor of shape [num_layers, seq_len, seq_len] (already head-fused) discard_ratio: Fraction of lowest attention weights to discard head_fusion: How to combine attention heads ("mean", "max", "min") num_prefix_tokens: Number of special tokens (1 for ViT cls, 2 for DeiT cls+dist) Returns: Attention map as numpy array of shape (grid_size, grid_size) """ # Default grid size for ViT-Base (14x14 patches from 224x224 with 16x16 patch size) default_grid_size = 14 # Handle empty input if attentions is None: logger.warning("No attention tensors provided (None)") return np.zeros((default_grid_size, default_grid_size), dtype=np.float32) # Convert tensor to list if needed if isinstance(attentions, torch.Tensor): if attentions.numel() == 0: logger.warning("Empty attention tensor provided") return np.zeros((default_grid_size, default_grid_size), dtype=np.float32) # Convert stacked tensor to list attentions = [attentions[i] for i in range(attentions.shape[0])] # Check if list is empty if len(attentions) == 0: logger.warning("Empty attention list provided") return np.zeros((default_grid_size, default_grid_size), dtype=np.float32) result = None for attention in attentions: # Handle different input formats if attention.dim() == 2: # Already fused: [seq_len, seq_len] attention_fused = attention.unsqueeze(0) # [1, seq, seq] elif attention.dim() == 3: # Batched without heads or already fused: [B, seq, seq] attention_fused = attention elif attention.dim() == 4: # Full attention: [B, heads, seq, seq] - fuse heads if head_fusion == "mean": attention_fused = attention.mean(dim=1) # [B, seq, seq] elif head_fusion == "max": attention_fused = attention.max(dim=1)[0] elif head_fusion == "min": attention_fused = attention.min(dim=1)[0] else: attention_fused = attention.mean(dim=1) else: logger.warning(f"Unexpected attention shape: {attention.shape}") continue # Discard low attention (optional) if discard_ratio > 0: flat = attention_fused.view(attention_fused.size(0), -1) threshold = torch.quantile(flat, discard_ratio, dim=1, keepdim=True) threshold = threshold.view(attention_fused.size(0), 1, 1) attention_fused = torch.where( attention_fused < threshold, torch.zeros_like(attention_fused), attention_fused ) # Renormalize attention_fused = attention_fused / (attention_fused.sum(dim=-1, keepdim=True) + 1e-9) # Add identity for residual connection seq_len = attention_fused.size(-1) identity = torch.eye(seq_len, device=attention_fused.device, dtype=attention_fused.dtype) attention_with_residual = 0.5 * attention_fused + 0.5 * identity.unsqueeze(0) # Matrix multiply through layers if result is None: result = attention_with_residual else: result = torch.bmm(attention_with_residual, result) # Extract CLS token attention to all patch tokens # result shape: [B, seq_len, seq_len] # CLS token is at index 0, patches start at index num_prefix_tokens cls_attention = result[0, 0, num_prefix_tokens:] # [num_patches] # Reshape to grid num_patches = cls_attention.size(0) grid_size = int(math.sqrt(num_patches)) if grid_size * grid_size != num_patches: logger.warning(f"Non-square number of patches: {num_patches}") # Pad or truncate to nearest square grid_size = int(math.ceil(math.sqrt(num_patches))) padded = torch.zeros(grid_size * grid_size, device=cls_attention.device) padded[:num_patches] = cls_attention cls_attention = padded attention_map = cls_attention.reshape(grid_size, grid_size).cpu().numpy() # Normalize attention_map = attention_map - attention_map.min() if attention_map.max() > 0: attention_map = attention_map / attention_map.max() return attention_map def heatmap_to_base64( heatmap: np.ndarray, colormap: str = "turbo", output_size: Optional[Tuple[int, int]] = None ) -> str: """ Convert a heatmap array to base64-encoded PNG string. Args: heatmap: 2D numpy array with values in [0, 1] colormap: Matplotlib colormap name ("turbo", "jet", "viridis", "inferno") output_size: Optional (width, height) to resize to Returns: Base64-encoded PNG string (without data:image/png;base64, prefix) """ import matplotlib matplotlib.use('Agg') # Non-interactive backend import matplotlib.pyplot as plt import matplotlib.cm as cm # Get colormap cmap = cm.get_cmap(colormap) # Apply colormap (returns RGBA) colored = cmap(heatmap) # Convert to uint8 RGB rgb = (colored[:, :, :3] * 255).astype(np.uint8) # Create PIL image img = Image.fromarray(rgb) # Resize if needed if output_size is not None: img = img.resize(output_size, Image.BILINEAR) # Save to bytes buffer = io.BytesIO() img.save(buffer, format='PNG', optimize=True) buffer.seek(0) # Encode to base64 encoded = base64.b64encode(buffer.getvalue()).decode('utf-8') return encoded def overlay_heatmap_on_image( image: Union[np.ndarray, Image.Image], heatmap: np.ndarray, alpha: float = 0.5, colormap: str = "turbo" ) -> str: """ Overlay a heatmap on an image and return as base64 PNG. Args: image: Original image (numpy array HWC or PIL Image) heatmap: 2D heatmap array [0, 1] alpha: Blend factor (0 = image only, 1 = heatmap only) colormap: Matplotlib colormap name Returns: Base64-encoded PNG of the overlaid image """ import matplotlib matplotlib.use('Agg') import matplotlib.cm as cm # Convert image to numpy if needed if isinstance(image, Image.Image): image = np.array(image) # Ensure image is uint8 RGB if image.dtype != np.uint8: image = (image * 255).astype(np.uint8) if image.ndim == 2: image = np.stack([image] * 3, axis=-1) elif image.shape[-1] == 1: image = np.concatenate([image] * 3, axis=-1) elif image.shape[-1] == 4: image = image[:, :, :3] # Resize heatmap to match image size h, w = image.shape[:2] heatmap_resized = np.array( Image.fromarray((heatmap * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR) ) / 255.0 # Apply colormap cmap = cm.get_cmap(colormap) heatmap_colored = cmap(heatmap_resized)[:, :, :3] heatmap_colored = (heatmap_colored * 255).astype(np.uint8) # Blend blended = ( (1 - alpha) * image.astype(np.float32) + alpha * heatmap_colored.astype(np.float32) ).astype(np.uint8) # Convert to base64 img = Image.fromarray(blended) buffer = io.BytesIO() img.save(buffer, format='PNG', optimize=True) buffer.seek(0) return base64.b64encode(buffer.getvalue()).decode('utf-8') class AttentionExtractor: """ Hook-based attention extractor for ViT/DeiT models. Registers hooks on transformer blocks to capture attention weights during forward pass. Usage: extractor = AttentionExtractor(model.blocks) output = model(input) attentions = extractor.get_attentions() extractor.clear() """ def __init__(self, blocks: nn.ModuleList): """ Args: blocks: List of transformer blocks (each should have .attn attribute) """ self.attentions: List[torch.Tensor] = [] self._hooks: List = [] for block in blocks: if hasattr(block, 'attn'): # Hook into the attention module # We need to capture after softmax, before dropout # timm stores attention in attn.attn_drop or we can compute from qkv hook = block.attn.register_forward_hook(self._make_hook()) self._hooks.append(hook) def _make_hook(self): """Create a forward hook that captures attention weights.""" def hook(module, input, output): # For timm ViT, we need to recompute attention from qkv # The module receives x and outputs x after attention # We'll store a flag and compute in get_attentions pass return hook def extract_attention_from_block( self, block: nn.Module, x: torch.Tensor ) -> torch.Tensor: """ Extract attention weights from a single transformer block. Args: block: Transformer block with attention module x: Input tensor [B, seq_len, embed_dim] Returns: Attention weights [B, num_heads, seq_len, seq_len] """ attn = block.attn B, N, C = x.shape # Get qkv qkv = attn.qkv(x).reshape(B, N, 3, attn.num_heads, C // attn.num_heads) qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, heads, N, head_dim] q, k, v = qkv[0], qkv[1], qkv[2] # Compute attention scale = (C // attn.num_heads) ** -0.5 attn_weights = (q @ k.transpose(-2, -1)) * scale attn_weights = attn_weights.softmax(dim=-1) return attn_weights def get_attentions(self) -> List[torch.Tensor]: """Return captured attention tensors.""" return self.attentions def clear(self): """Clear captured attentions.""" self.attentions.clear() def remove_hooks(self): """Remove all hooks.""" for hook in self._hooks: hook.remove() self._hooks.clear() def __del__(self): self.remove_hooks() def compute_vit_attention_rollout( model: nn.Module, input_tensor: torch.Tensor, blocks_attr: str = "blocks", num_prefix_tokens: int = 1, output_size: Tuple[int, int] = (224, 224) ) -> np.ndarray: """ Compute attention rollout for a ViT-style model. Args: model: The ViT model (should have .blocks attribute with transformer layers) input_tensor: Input image tensor [1, 3, H, W] blocks_attr: Attribute name for transformer blocks (e.g., "blocks" or "vit.blocks") num_prefix_tokens: Number of prefix tokens (1 for CLS, 2 for CLS+DIST) output_size: Size to resize output heatmap Returns: Attention heatmap as numpy array [H, W] in range [0, 1] """ model.eval() # Navigate to blocks blocks = model for attr in blocks_attr.split('.'): blocks = getattr(blocks, attr) attentions = [] # Hook to capture attention weights def make_attn_hook(storage): def hook(module, input, output): # Recompute attention weights x = input[0] B, N, C = x.shape # Get qkv projection qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # Compute attention weights scale = (C // module.num_heads) ** -0.5 attn = (q @ k.transpose(-2, -1)) * scale attn = attn.softmax(dim=-1) storage.append(attn.detach()) return hook # Register hooks hooks = [] for block in blocks: if hasattr(block, 'attn'): h = block.attn.register_forward_hook(make_attn_hook(attentions)) hooks.append(h) try: # Forward pass with torch.no_grad(): _ = model(input_tensor) # Compute rollout if attentions: rollout = attention_rollout( attentions, num_prefix_tokens=num_prefix_tokens ) # Resize to output size rollout_img = Image.fromarray((rollout * 255).astype(np.uint8)) rollout_img = rollout_img.resize(output_size, Image.BILINEAR) rollout = np.array(rollout_img) / 255.0 return rollout else: logger.warning("No attention weights captured") return np.zeros(output_size, dtype=np.float32) finally: # Clean up hooks for h in hooks: h.remove() def compute_focus_summary( heatmap: np.ndarray, threshold: float = 0.5 ) -> str: """ Compute a human-readable summary of where the heatmap focuses. Analyzes the heatmap to describe the spatial distribution of high activation regions (e.g., "concentrated on upper-left", "diffuse across image"). Args: heatmap: 2D numpy array with values in [0, 1], shape (H, W) threshold: Threshold for considering a region as "high activation" Returns: Human-readable focus summary string """ if heatmap is None or heatmap.size == 0: return "no activation data available" # Normalize heatmap heatmap = np.array(heatmap, dtype=np.float32) if heatmap.max() > 0: heatmap = heatmap / heatmap.max() h, w = heatmap.shape # Compute centroid of high activation regions mask = heatmap > threshold if not mask.any(): # Lower threshold if nothing above it mask = heatmap > (heatmap.max() * 0.5) if not mask.any(): return "very low activation across entire image" # Get coordinates of activated pixels y_coords, x_coords = np.where(mask) # Compute centroid centroid_y = y_coords.mean() / h # Normalized to [0, 1] centroid_x = x_coords.mean() / w # Normalized to [0, 1] # Compute spread (standard deviation normalized by image size) spread_y = y_coords.std() / h if len(y_coords) > 1 else 0 spread_x = x_coords.std() / w if len(x_coords) > 1 else 0 spread = (spread_y + spread_x) / 2 # Compute coverage (fraction of image with high activation) coverage = mask.sum() / mask.size # Build description parts = [] # Describe spread if spread < 0.15: parts.append("highly concentrated") elif spread < 0.25: parts.append("moderately concentrated") else: parts.append("spread across") # Describe location location_parts = [] # Vertical position if centroid_y < 0.33: location_parts.append("upper") elif centroid_y > 0.67: location_parts.append("lower") else: location_parts.append("middle") # Horizontal position if centroid_x < 0.33: location_parts.append("left") elif centroid_x > 0.67: location_parts.append("right") else: location_parts.append("center") # Combine location (avoid "middle center") if location_parts == ["middle", "center"]: location = "central region" else: location = "-".join(location_parts) + " region" parts.append(location) # Add coverage note for diffuse patterns if coverage > 0.4: parts.append(f"(~{int(coverage*100)}% of image)") summary = " ".join(parts) # Add semantic hints based on common portrait regions # Center typically = face, edges/corners = background if centroid_y < 0.5 and 0.3 < centroid_x < 0.7 and spread < 0.2: summary += " (likely face/subject area)" elif spread > 0.3: summary += " (examining multiple regions)" return summary