| | """ |
| | Explainability utilities for DeepFake detection models. |
| | |
| | Provides: |
| | - GradCAM: For CNN-based models (EfficientNet, CompactGradientNet) |
| | - Attention Rollout: For ViT/DeiT transformer models |
| | - Heatmap visualization utilities |
| | """ |
| |
|
| | import base64 |
| | import io |
| | import math |
| | from typing import List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from PIL import Image |
| |
|
| | from app.core.logging import get_logger |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class GradCAM:
    """
    Gradient-weighted Class Activation Mapping for CNN models.

    Computes importance heatmaps by weighting feature map activations
    by the gradients flowing into them from the target class.

    Usage:
        gradcam = GradCAM(model, target_layer)
        heatmap = gradcam(input_tensor, target_class=1)
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        """
        Args:
            model: The CNN model
            target_layer: The convolutional layer to compute Grad-CAM on
                (typically the last conv layer before pooling)
        """
        self.model = model
        self.target_layer = target_layer
        # Filled in by the hooks during forward/backward passes.
        self.gradients: Optional[torch.Tensor] = None
        self.activations: Optional[torch.Tensor] = None
        self._hooks: List = []

        self._register_hooks()

    def _register_hooks(self):
        """Register forward and backward hooks on target layer."""
        def forward_hook(module, input, output):
            # Capture the feature maps produced by the target layer.
            self.activations = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # Capture gradients flowing back into the target layer's output.
            self.gradients = grad_output[0].detach()

        self._hooks.append(
            self.target_layer.register_forward_hook(forward_hook)
        )
        self._hooks.append(
            self.target_layer.register_full_backward_hook(backward_hook)
        )

    def remove_hooks(self):
        """Remove registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __call__(
        self,
        input_tensor: torch.Tensor,
        target_class: Optional[int] = None,
        output_size: Tuple[int, int] = (224, 224)
    ) -> np.ndarray:
        """
        Compute Grad-CAM heatmap.

        Args:
            input_tensor: Input image tensor [1, C, H, W]
            target_class: Class index to compute gradients for.
                If None, uses the predicted class.
            output_size: Size to resize the heatmap to (H, W)

        Returns:
            Normalized heatmap as numpy array [H, W] in range [0, 1]
        """
        # Reset captures so a failed pass cannot silently reuse stale
        # tensors from a previous invocation.
        self.gradients = None
        self.activations = None

        self.model.eval()

        # Clone so the caller's tensor is not mutated by requires_grad_.
        input_tensor = input_tensor.clone().requires_grad_(True)

        output = self.model(input_tensor)

        # Some models return (logits, aux); keep only the logits.
        if isinstance(output, tuple):
            logits = output[0]
        else:
            logits = output

        # Normalize to a [batch, num_classes] shape.
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)

        if target_class is None:
            target_class = logits.argmax(dim=1).item()

        self.model.zero_grad()

        if logits.shape[-1] > 1:
            # Multi-class head: score of the requested class.
            target_score = logits[0, target_class]
        else:
            # Single-logit (binary) head: negate for the "class 0" view.
            target_score = logits[0, 0] if target_class == 1 else -logits[0, 0]

        # No retain_graph: the graph is not reused, so let autograd free it.
        target_score.backward()

        if self.gradients is None or self.activations is None:
            logger.warning("Gradients or activations not captured")
            return np.zeros(output_size, dtype=np.float32)

        # Global-average-pool the gradients to get per-channel weights.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)

        # Weighted sum of feature maps, then ReLU to keep positive evidence.
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # Min-max normalize to [0, 1]; guard against an all-zero map.
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            cam = cam / cam_max

        cam = F.interpolate(
            cam,
            size=output_size,
            mode='bilinear',
            align_corners=False
        )

        heatmap = cam.squeeze().cpu().numpy()

        return heatmap

    def __del__(self):
        # __del__ may run during interpreter teardown or after a failed
        # __init__, when attributes/modules can already be gone; never raise.
        try:
            self.remove_hooks()
        except Exception:
            pass
| |
|
| |
|
def attention_rollout(
    attentions: Union[List[torch.Tensor], torch.Tensor],
    discard_ratio: float = 0.0,
    head_fusion: str = "mean",
    num_prefix_tokens: int = 1
) -> np.ndarray:
    """
    Compute attention rollout for Vision Transformers.

    Aggregates attention across all layers by matrix multiplication,
    accounting for residual connections.

    Args:
        attentions: Attention tensors from each layer. Can be:
            - List of tensors, each shape [batch, num_heads, seq_len, seq_len] or [seq_len, seq_len]
            - Stacked tensor of shape [num_layers, seq_len, seq_len] (already head-fused)
        discard_ratio: Fraction of lowest attention weights to discard
        head_fusion: How to combine attention heads ("mean", "max", "min")
        num_prefix_tokens: Number of special tokens (1 for ViT cls, 2 for DeiT cls+dist)

    Returns:
        Attention map as numpy array of shape (grid_size, grid_size)
    """
    # Fallback grid for degenerate inputs (14x14 = 224px / 16px patches).
    default_grid_size = 14

    if attentions is None:
        logger.warning("No attention tensors provided (None)")
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    # A stacked tensor is unbundled into a per-layer list.
    if isinstance(attentions, torch.Tensor):
        if attentions.numel() == 0:
            logger.warning("Empty attention tensor provided")
            return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

        attentions = [attentions[i] for i in range(attentions.shape[0])]

    if len(attentions) == 0:
        logger.warning("Empty attention list provided")
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    result = None

    for attention in attentions:
        # Normalize every layer to [batch, seq_len, seq_len].
        if attention.dim() == 2:
            attention_fused = attention.unsqueeze(0)
        elif attention.dim() == 3:
            attention_fused = attention
        elif attention.dim() == 4:
            # Fuse heads according to the requested strategy (mean default).
            if head_fusion == "mean":
                attention_fused = attention.mean(dim=1)
            elif head_fusion == "max":
                attention_fused = attention.max(dim=1)[0]
            elif head_fusion == "min":
                attention_fused = attention.min(dim=1)[0]
            else:
                attention_fused = attention.mean(dim=1)
        else:
            logger.warning(f"Unexpected attention shape: {attention.shape}")
            continue

        # Zero out the lowest-weighted entries, per batch element.
        if discard_ratio > 0:
            flat = attention_fused.view(attention_fused.size(0), -1)
            threshold = torch.quantile(flat, discard_ratio, dim=1, keepdim=True)
            threshold = threshold.view(attention_fused.size(0), 1, 1)
            attention_fused = torch.where(
                attention_fused < threshold,
                torch.zeros_like(attention_fused),
                attention_fused
            )

        # Re-normalize rows so they remain a distribution after discarding.
        attention_fused = attention_fused / (attention_fused.sum(dim=-1, keepdim=True) + 1e-9)

        # Model the residual connection: tokens half-attend to themselves.
        seq_len = attention_fused.size(-1)
        identity = torch.eye(seq_len, device=attention_fused.device, dtype=attention_fused.dtype)
        attention_with_residual = 0.5 * attention_fused + 0.5 * identity.unsqueeze(0)

        if result is None:
            result = attention_with_residual
        else:
            result = torch.bmm(attention_with_residual, result)

    # Fix: if every layer was skipped above, result is still None and the
    # indexing below would raise; fall back to an empty map instead.
    if result is None:
        logger.warning("No usable attention tensors after shape filtering")
        return np.zeros((default_grid_size, default_grid_size), dtype=np.float32)

    # CLS-token row (batch 0), excluding the prefix/special tokens.
    cls_attention = result[0, 0, num_prefix_tokens:]

    num_patches = cls_attention.size(0)
    grid_size = int(math.sqrt(num_patches))

    if grid_size * grid_size != num_patches:
        logger.warning(f"Non-square number of patches: {num_patches}")
        # Pad up to the next square so the reshape below succeeds.
        grid_size = int(math.ceil(math.sqrt(num_patches)))
        padded = torch.zeros(
            grid_size * grid_size,
            device=cls_attention.device,
            dtype=cls_attention.dtype,
        )
        padded[:num_patches] = cls_attention
        cls_attention = padded

    attention_map = cls_attention.reshape(grid_size, grid_size).cpu().numpy()

    # Min-max normalize to [0, 1]; an all-equal map stays all zeros.
    attention_map = attention_map - attention_map.min()
    if attention_map.max() > 0:
        attention_map = attention_map / attention_map.max()

    return attention_map
| |
|
| |
|
def heatmap_to_base64(
    heatmap: np.ndarray,
    colormap: str = "turbo",
    output_size: Optional[Tuple[int, int]] = None
) -> str:
    """
    Convert a heatmap array to base64-encoded PNG string.

    Args:
        heatmap: 2D numpy array with values in [0, 1]
        colormap: Matplotlib colormap name ("turbo", "jet", "viridis", "inferno")
        output_size: Optional (width, height) to resize to

    Returns:
        Base64-encoded PNG string (without data:image/png;base64, prefix)
    """
    # Deferred import keeps matplotlib out of module import time.
    import matplotlib
    matplotlib.use('Agg')

    # matplotlib.cm.get_cmap() was deprecated in 3.7 and removed in 3.9;
    # use the colormap registry, falling back for older installs.
    try:
        cmap = matplotlib.colormaps[colormap]
    except AttributeError:
        import matplotlib.cm as cm
        cmap = cm.get_cmap(colormap)

    # Map [0, 1] scalars to RGBA, drop the alpha channel, scale to bytes.
    colored = cmap(heatmap)
    rgb = (colored[:, :, :3] * 255).astype(np.uint8)

    img = Image.fromarray(rgb)

    if output_size is not None:
        img = img.resize(output_size, Image.BILINEAR)

    buffer = io.BytesIO()
    img.save(buffer, format='PNG', optimize=True)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')
| |
|
| |
|
def overlay_heatmap_on_image(
    image: Union[np.ndarray, Image.Image],
    heatmap: np.ndarray,
    alpha: float = 0.5,
    colormap: str = "turbo"
) -> str:
    """
    Overlay a heatmap on an image and return as base64 PNG.

    Args:
        image: Original image (numpy array HWC or PIL Image)
        heatmap: 2D heatmap array [0, 1]
        alpha: Blend factor (0 = image only, 1 = heatmap only)
        colormap: Matplotlib colormap name

    Returns:
        Base64-encoded PNG of the overlaid image
    """
    import matplotlib
    matplotlib.use('Agg')

    if isinstance(image, Image.Image):
        image = np.array(image)

    # Normalize to uint8 RGB regardless of input dtype/channel layout.
    # NOTE(review): assumes non-uint8 inputs are floats in [0, 1] — confirm
    # against callers before passing e.g. uint16 images here.
    if image.dtype != np.uint8:
        image = (image * 255).astype(np.uint8)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 1:
        image = np.concatenate([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[:, :, :3]

    # Resize the heatmap to the image's spatial size.
    h, w = image.shape[:2]
    heatmap_resized = np.array(
        Image.fromarray((heatmap * 255).astype(np.uint8)).resize((w, h), Image.BILINEAR)
    ) / 255.0

    # matplotlib.cm.get_cmap() was deprecated in 3.7 and removed in 3.9;
    # use the colormap registry, falling back for older installs.
    try:
        cmap = matplotlib.colormaps[colormap]
    except AttributeError:
        import matplotlib.cm as cm
        cmap = cm.get_cmap(colormap)
    heatmap_colored = cmap(heatmap_resized)[:, :, :3]
    heatmap_colored = (heatmap_colored * 255).astype(np.uint8)

    # Alpha-blend in float to avoid uint8 wraparound, then convert back.
    blended = (
        (1 - alpha) * image.astype(np.float32) +
        alpha * heatmap_colored.astype(np.float32)
    ).astype(np.uint8)

    img = Image.fromarray(blended)
    buffer = io.BytesIO()
    img.save(buffer, format='PNG', optimize=True)

    return base64.b64encode(buffer.getvalue()).decode('utf-8')
| |
|
| |
|
class AttentionExtractor:
    """
    Hook-based attention extractor for ViT/DeiT models.

    Registers hooks on transformer blocks to capture attention weights
    during forward pass. Because fused attention implementations do not
    expose their weights, the hook recomputes them from the module's
    qkv projection (this roughly doubles the attention compute per block).

    Usage:
        extractor = AttentionExtractor(model.blocks)
        output = model(input)
        attentions = extractor.get_attentions()
        extractor.clear()
    """

    def __init__(self, blocks: nn.ModuleList):
        """
        Args:
            blocks: List of transformer blocks (each should have .attn attribute)
        """
        self.attentions: List[torch.Tensor] = []
        self._hooks: List = []

        for block in blocks:
            if hasattr(block, 'attn'):
                hook = block.attn.register_forward_hook(self._make_hook())
                self._hooks.append(hook)

    @staticmethod
    def _compute_attention(attn: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Recompute softmax attention weights from a timm-style attn module.

        Assumes the module has a fused `qkv` Linear and a `num_heads`
        attribute (timm ViT/DeiT convention) — TODO confirm for other models.
        """
        B, N, C = x.shape

        qkv = attn.qkv(x).reshape(B, N, 3, attn.num_heads, C // attn.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]

        # Standard scaled dot-product attention scores.
        scale = (C // attn.num_heads) ** -0.5
        attn_weights = (q @ k.transpose(-2, -1)) * scale
        return attn_weights.softmax(dim=-1)

    def _make_hook(self):
        """Create a forward hook that captures attention weights."""
        def hook(module, input, output):
            # Fix: the previous implementation was a no-op, so
            # get_attentions() always returned []. Recompute the weights
            # from the hooked module's own qkv projection and store them.
            self.attentions.append(
                self._compute_attention(module, input[0]).detach()
            )
        return hook

    def extract_attention_from_block(
        self,
        block: nn.Module,
        x: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract attention weights from a single transformer block.

        Args:
            block: Transformer block with attention module
            x: Input tensor [B, seq_len, embed_dim]

        Returns:
            Attention weights [B, num_heads, seq_len, seq_len]
        """
        return self._compute_attention(block.attn, x)

    def get_attentions(self) -> List[torch.Tensor]:
        """Return captured attention tensors."""
        return self.attentions

    def clear(self):
        """Clear captured attentions."""
        self.attentions.clear()

    def remove_hooks(self):
        """Remove all hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __del__(self):
        # Never raise from __del__ (teardown order is unpredictable).
        try:
            self.remove_hooks()
        except Exception:
            pass
| |
|
| |
|
def compute_vit_attention_rollout(
    model: nn.Module,
    input_tensor: torch.Tensor,
    blocks_attr: str = "blocks",
    num_prefix_tokens: int = 1,
    output_size: Tuple[int, int] = (224, 224)
) -> np.ndarray:
    """
    Compute attention rollout for a ViT-style model.

    Args:
        model: The ViT model (should have .blocks attribute with transformer layers)
        input_tensor: Input image tensor [1, 3, H, W]
        blocks_attr: Attribute name for transformer blocks (e.g., "blocks" or "vit.blocks")
        num_prefix_tokens: Number of prefix tokens (1 for CLS, 2 for CLS+DIST)
        output_size: Size to resize output heatmap

    Returns:
        Attention heatmap as numpy array [H, W] in range [0, 1]
    """
    model.eval()

    # Resolve a dotted attribute path like "vit.blocks".
    blocks = model
    for attr in blocks_attr.split('.'):
        blocks = getattr(blocks, attr)

    attentions = []

    def make_attn_hook(storage):
        def hook(module, input, output):
            # Fused attention doesn't expose its weights; recompute them
            # from the module's qkv projection (timm ViT/DeiT convention).
            x = input[0]
            B, N, C = x.shape

            qkv = module.qkv(x).reshape(B, N, 3, module.num_heads, C // module.num_heads)
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]

            scale = (C // module.num_heads) ** -0.5
            attn = (q @ k.transpose(-2, -1)) * scale
            attn = attn.softmax(dim=-1)

            storage.append(attn.detach())
        return hook

    hooks = []
    for block in blocks:
        if hasattr(block, 'attn'):
            h = block.attn.register_forward_hook(make_attn_hook(attentions))
            hooks.append(h)

    try:
        with torch.no_grad():
            _ = model(input_tensor)

        if attentions:
            rollout = attention_rollout(
                attentions,
                num_prefix_tokens=num_prefix_tokens
            )

            # Upscale in float with bilinear interpolation. (The previous
            # round-trip through an 8-bit PIL image quantized the map to
            # 256 levels before resizing, losing precision for free.)
            rollout_t = torch.from_numpy(np.ascontiguousarray(rollout)).float()
            rollout_t = F.interpolate(
                rollout_t.unsqueeze(0).unsqueeze(0),
                size=output_size,
                mode='bilinear',
                align_corners=False,
            )
            return rollout_t.squeeze(0).squeeze(0).numpy()
        else:
            logger.warning("No attention weights captured")
            return np.zeros(output_size, dtype=np.float32)

    finally:
        # Always detach the hooks, even if the forward pass raises.
        for h in hooks:
            h.remove()
| |
|
| |
|
def compute_focus_summary(
    heatmap: np.ndarray,
    threshold: float = 0.5
) -> str:
    """
    Compute a human-readable summary of where the heatmap focuses.

    Analyzes the heatmap to describe the spatial distribution of high
    activation regions (e.g., "concentrated on upper-left", "diffuse across image").

    Args:
        heatmap: 2D numpy array with values in [0, 1], shape (H, W)
        threshold: Threshold for considering a region as "high activation"

    Returns:
        Human-readable focus summary string
    """
    # Guard: nothing to analyze.
    if heatmap is None or heatmap.size == 0:
        return "no activation data available"

    # Work on a normalized float copy so thresholds are scale-invariant.
    grid = np.array(heatmap, dtype=np.float32)
    peak = grid.max()
    if peak > 0:
        grid = grid / peak

    rows, cols = grid.shape

    # Hot region: above the caller's threshold, falling back to half the
    # (normalized) peak when nothing clears it.
    hot = grid > threshold
    if not hot.any():
        hot = grid > (grid.max() * 0.5)
    if not hot.any():
        return "very low activation across entire image"

    ys, xs = np.where(hot)

    # Centroid of the hot region, as fractions of image height/width.
    cy = ys.mean() / rows
    cx = xs.mean() / cols

    # Dispersion: mean of the per-axis normalized standard deviations.
    sy = ys.std() / rows if len(ys) > 1 else 0
    sx = xs.std() / cols if len(xs) > 1 else 0
    dispersion = (sy + sx) / 2

    coverage = hot.sum() / hot.size

    # Qualitative concentration bucket.
    if dispersion < 0.15:
        concentration = "highly concentrated"
    elif dispersion < 0.25:
        concentration = "moderately concentrated"
    else:
        concentration = "spread across"

    # Coarse 3x3 localization of the centroid.
    vertical = "upper" if cy < 0.33 else ("lower" if cy > 0.67 else "middle")
    horizontal = "left" if cx < 0.33 else ("right" if cx > 0.67 else "center")

    if (vertical, horizontal) == ("middle", "center"):
        region = "central region"
    else:
        region = f"{vertical}-{horizontal} region"

    pieces = [concentration, region]
    if coverage > 0.4:
        pieces.append(f"(~{int(coverage * 100)}% of image)")

    text = " ".join(pieces)

    # Heuristic qualifiers: a tight upper-central blob is usually the
    # face/subject; very diffuse maps touch multiple regions.
    if cy < 0.5 and 0.3 < cx < 0.7 and dispersion < 0.2:
        text += " (likely face/subject area)"
    elif dispersion > 0.3:
        text += " (examining multiple regions)"

    return text
| |
|
| |
|