""" DeepGuard — Heatmap Generation Module Strategy: 1. PRIMARY: Attention Rollout — extract multi-head attention matrices from ONNX intermediate outputs and roll them up through all layers. 2. FALLBACK: Frequency Anomaly + Gradient Saliency — if attention weights are not exported, compute a forensically meaningful heatmap using DCT frequency analysis and Sobel edge gradients. (Pure NumPy, <10ms, no additional inference passes.) """ import io import numpy as np from PIL import Image from scipy.ndimage import gaussian_filter import base64 from typing import Optional # --------------------------------------------------------------------------- # Public entry point # --------------------------------------------------------------------------- def generate_heatmap( image: Image.Image, output_dict: dict, confidence_score: float, ) -> str: """ Generate a transparent red/yellow heatmap overlay. Args: image: Original PIL image (any size). output_dict: Raw ONNX output dict {name: ndarray}. confidence_score: Model fake probability [0, 1]. Returns: data URI string: "data:image/png;base64,..." """ img224 = image.convert("RGB").resize((224, 224), Image.BILINEAR) img_arr = np.array(img224, dtype=np.float32) # Try attention rollout first attn_keys = [ k for k in output_dict if "attn" in k.lower() or "attention" in k.lower() ] heat_map = None if attn_keys: heat_map = _attention_rollout(output_dict, attn_keys) if heat_map is None: heat_map = _frequency_saliency(img_arr, confidence_score) overlay = _apply_overlay(img_arr, heat_map) return _encode_png(overlay) # --------------------------------------------------------------------------- # Strategy 1: Attention Rollout # --------------------------------------------------------------------------- def _attention_rollout(output_dict: dict, attn_keys: list) -> Optional[np.ndarray]: """ Roll up multi-head attention matrices across all transformer layers. Returns a normalized (224, 224) float32 array or None on failure. """ try: # Sort keys to ensure layer order (layer_0, layer_1, ...) attn_keys_sorted = sorted(attn_keys) rollout = None for key in attn_keys_sorted: attn = output_dict[key] # Expected shape: (1, heads, seq_len, seq_len) if attn.ndim != 4: continue attn = attn.squeeze(0) # (heads, seq_len, seq_len) attn = attn.mean(axis=0) # Average heads → (seq_len, seq_len) # Add residual identity (attention rollout formula) identity = np.eye(attn.shape[0], dtype=np.float32) attn = 0.5 * attn + 0.5 * identity attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8) rollout = attn if rollout is None else np.matmul(rollout, attn) if rollout is None: return None # Row 0 = CLS token → attends to all patch tokens cls_attn = rollout[0, 1:] # Drop CLS itself → (num_patches,) num_patches = cls_attn.shape[0] side = int(np.sqrt(num_patches)) # 14 for ViT-base-patch16 if side * side != num_patches: return None patch_map = cls_attn.reshape(side, side) patch_map = (patch_map - patch_map.min()) / (patch_map.max() - patch_map.min() + 1e-8) # Upsample to 224×224 heat = _upsample(patch_map, 224, 224) heat = gaussian_filter(heat, sigma=8) return _normalize(heat) except Exception: return None # --------------------------------------------------------------------------- # Strategy 2: Frequency Anomaly + Sobel Saliency (pure NumPy fallback) # --------------------------------------------------------------------------- def _frequency_saliency(img_arr: np.ndarray, confidence_score: float) -> np.ndarray: """ Generate a heatmap from: - DCT/FFT frequency anomalies (AI images have characteristic frequency patterns) - Sobel gradient magnitude (AI fails at object/background boundaries) Both signals are combined and weighted by the confidence score. """ gray = 0.299 * img_arr[:, :, 0] + 0.587 * img_arr[:, :, 1] + 0.114 * img_arr[:, :, 2] gray_norm = gray / 255.0 # --- Frequency anomaly via 2D FFT --- fft = np.fft.fft2(gray_norm) fft_shift = np.fft.fftshift(fft) magnitude = np.log1p(np.abs(fft_shift)) # High-pass: keep frequencies above the center radius (AI images often # have unnaturally suppressed high-frequency noise) h, w = magnitude.shape cy, cx = h // 2, w // 2 Y, X = np.ogrid[:h, :w] r = np.sqrt((X - cx) ** 2 + (Y - cy) ** 2) # Anomaly score: deviation of high-freq energy from expected camera noise high_freq_mask = r > (min(h, w) * 0.15) freq_baseline = magnitude[high_freq_mask].mean() freq_map = np.abs(magnitude - freq_baseline) freq_map = _normalize(freq_map) # --- Sobel gradient magnitude --- ky = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32) kx = ky.T gx = _convolve2d(gray_norm, kx) gy = _convolve2d(gray_norm, ky) grad_map = np.sqrt(gx ** 2 + gy ** 2) grad_map = _normalize(grad_map) # Combine: weight by score — high-confidence → emphasize freq anomaly alpha = min(confidence_score * 1.2, 0.8) combined = alpha * freq_map + (1.0 - alpha) * grad_map # Smooth and normalize combined = gaussian_filter(combined, sigma=10) return _normalize(combined) def _convolve2d(img: np.ndarray, kernel: np.ndarray) -> np.ndarray: """Manual 2D convolution via stride tricks (no scipy dependency for this).""" from scipy.ndimage import convolve return convolve(img, kernel, mode="reflect") # --------------------------------------------------------------------------- # Colormap and overlay helpers # --------------------------------------------------------------------------- def _apply_overlay(img_arr: np.ndarray, heat: np.ndarray, alpha: float = 0.55) -> np.ndarray: """ Blend red/yellow heatmap over original image. Returns RGBA uint8 array (224, 224, 4). """ # Map heat [0,1] to RGBA: 0=transparent, 0.5=orange, 1.0=bright red r = np.ones_like(heat) # R channel: always full g = np.clip(1.0 - heat * 1.4, 0, 1) # G: fades out → red b = np.zeros_like(heat) # B: always 0 overlay_rgb = np.stack([r, g, b], axis=-1) # (224,224,3) float [0,1] overlay_alpha = np.clip(heat * alpha * 255, 0, 255) # (224,224) float # Blend: result = img * (1 - a) + color * a a3 = (overlay_alpha[:, :, np.newaxis] / 255.0) blended = (img_arr / 255.0) * (1.0 - a3) + overlay_rgb * a3 blended = np.clip(blended * 255, 0, 255).astype(np.uint8) # Add alpha channel alpha_ch = overlay_alpha.astype(np.uint8) # Keep full opacity everywhere, just use blend for color full_alpha = np.full((224, 224), 255, dtype=np.uint8) rgba = np.dstack([blended, full_alpha]) return rgba def _encode_png(rgba_arr: np.ndarray) -> str: """Encode RGBA array to data URI.""" pil_img = Image.fromarray(rgba_arr, mode="RGBA") buf = io.BytesIO() pil_img.save(buf, format="PNG", optimize=True) b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return f"data:image/png;base64,{b64}" # --------------------------------------------------------------------------- # Utility helpers # --------------------------------------------------------------------------- def _normalize(arr: np.ndarray) -> np.ndarray: mn, mx = arr.min(), arr.max() if mx - mn < 1e-8: return np.zeros_like(arr, dtype=np.float32) return ((arr - mn) / (mx - mn)).astype(np.float32) def _upsample(patch_map: np.ndarray, target_h: int, target_w: int) -> np.ndarray: """Bilinear upsample a small 2D patch map to target size using PIL.""" pil = Image.fromarray((patch_map * 255).astype(np.uint8), mode="L") pil = pil.resize((target_w, target_h), Image.BILINEAR) return np.array(pil, dtype=np.float32) / 255.0