Spaces:
Sleeping
Sleeping
| """ | |
| DeepGuard — Heatmap Generation Module | |
| Strategy: | |
| 1. PRIMARY: Attention Rollout — extract multi-head attention matrices from | |
| ONNX intermediate outputs and roll them up through all layers. | |
| 2. FALLBACK: Frequency Anomaly + Gradient Saliency — if attention weights | |
| are not exported, compute a forensically meaningful heatmap | |
| using FFT frequency analysis and Sobel edge gradients. | |
| (NumPy/SciPy only, <10ms, no additional inference passes.) | |
| """ | |
| import io | |
| import numpy as np | |
| from PIL import Image | |
| from scipy.ndimage import gaussian_filter | |
| import base64 | |
| from typing import Optional | |
| # --------------------------------------------------------------------------- | |
| # Public entry point | |
| # --------------------------------------------------------------------------- | |
def generate_heatmap(
    image: Image.Image,
    output_dict: dict,
    confidence_score: float,
) -> str:
    """
    Render a red/yellow forensic heatmap over ``image`` and return it as a PNG data URI.

    The image is converted to RGB and resized to 224x224 before any analysis.
    If any ONNX output name contains "attn" or "attention" (case-insensitive),
    attention rollout is tried first; if it produces nothing, the
    frequency/gradient saliency fallback is used instead.

    Args:
        image: Original PIL image (any size).
        output_dict: Raw ONNX output dict {name: ndarray}.
        confidence_score: Model fake probability [0, 1].

    Returns:
        data URI string: "data:image/png;base64,..."
    """
    resized = image.convert("RGB").resize((224, 224), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32)

    def _mentions_attention(name):
        lowered = name.lower()
        return "attn" in lowered or "attention" in lowered

    candidate_keys = list(filter(_mentions_attention, output_dict))
    heat = _attention_rollout(output_dict, candidate_keys) if candidate_keys else None
    if heat is None:
        # No usable attention tensors: derive saliency from the pixels instead.
        heat = _frequency_saliency(pixels, confidence_score)

    return _encode_png(_apply_overlay(pixels, heat))
| # --------------------------------------------------------------------------- | |
| # Strategy 1: Attention Rollout | |
| # --------------------------------------------------------------------------- | |
| def _attention_rollout(output_dict: dict, attn_keys: list) -> Optional[np.ndarray]: | |
| """ | |
| Roll up multi-head attention matrices across all transformer layers. | |
| Returns a normalized (224, 224) float32 array or None on failure. | |
| """ | |
| try: | |
| # Sort keys to ensure layer order (layer_0, layer_1, ...) | |
| attn_keys_sorted = sorted(attn_keys) | |
| rollout = None | |
| for key in attn_keys_sorted: | |
| attn = output_dict[key] # Expected shape: (1, heads, seq_len, seq_len) | |
| if attn.ndim != 4: | |
| continue | |
| attn = attn.squeeze(0) # (heads, seq_len, seq_len) | |
| attn = attn.mean(axis=0) # Average heads β (seq_len, seq_len) | |
| # Add residual identity (attention rollout formula) | |
| identity = np.eye(attn.shape[0], dtype=np.float32) | |
| attn = 0.5 * attn + 0.5 * identity | |
| attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8) | |
| rollout = attn if rollout is None else np.matmul(rollout, attn) | |
| if rollout is None: | |
| return None | |
| # Row 0 = CLS token β attends to all patch tokens | |
| cls_attn = rollout[0, 1:] # Drop CLS itself β (num_patches,) | |
| num_patches = cls_attn.shape[0] | |
| side = int(np.sqrt(num_patches)) # 14 for ViT-base-patch16 | |
| if side * side != num_patches: | |
| return None | |
| patch_map = cls_attn.reshape(side, side) | |
| patch_map = (patch_map - patch_map.min()) / (patch_map.max() - patch_map.min() + 1e-8) | |
| # Upsample to 224Γ224 | |
| heat = _upsample(patch_map, 224, 224) | |
| heat = gaussian_filter(heat, sigma=8) | |
| return _normalize(heat) | |
| except Exception: | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Strategy 2: Frequency Anomaly + Sobel Saliency (pure NumPy fallback) | |
| # --------------------------------------------------------------------------- | |
def _frequency_saliency(img_arr: np.ndarray, confidence_score: float) -> np.ndarray:
    """
    Build a fallback saliency map, in image (pixel) coordinates, from two cues.

    - High-frequency residual energy: the grayscale image is high-pass
      filtered in the FFT domain and transformed back, giving each pixel a
      local measure of fine-detail energy. Pixels whose energy deviates from
      the image-wide mean (unusually smooth or unusually noisy) score high.
    - Sobel gradient magnitude (edge strength).

    The cues are blended with a weight derived from ``confidence_score``,
    Gaussian-smoothed, and min-max normalized.

    Args:
        img_arr: RGB array indexed as [H, W, channel], values expected in [0, 255].
        confidence_score: Model fake probability; the frequency-cue weight is
            ``clip(1.2 * score, 0, 0.8)``.

    Returns:
        (H, W) float32 array in [0, 1].
    """
    gray = 0.299 * img_arr[:, :, 0] + 0.587 * img_arr[:, :, 1] + 0.114 * img_arr[:, :, 2]
    gray_norm = gray / 255.0

    # --- High-frequency residual, mapped back into pixel space ---
    # The FFT magnitude spectrum itself is indexed by frequency, not by pixel,
    # so it must not be blended pixel-wise with a spatial map. Keep only the
    # frequencies outside a central radius and invert to get a spatial residual.
    h, w = gray_norm.shape
    spectrum = np.fft.fftshift(np.fft.fft2(gray_norm))
    cy, cx = h // 2, w // 2
    yy, xx = np.ogrid[:h, :w]
    radius = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
    high_pass = radius > (min(h, w) * 0.15)
    residual = np.abs(np.fft.ifft2(np.fft.ifftshift(spectrum * high_pass)))
    hf_energy = gaussian_filter(residual ** 2, sigma=4)
    freq_map = _normalize(np.abs(hf_energy - hf_energy.mean()))

    # --- Sobel gradient magnitude ---
    ky = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)
    kx = ky.T
    gx = _convolve2d(gray_norm, kx)
    gy = _convolve2d(gray_norm, ky)
    grad_map = _normalize(np.sqrt(gx ** 2 + gy ** 2))

    # Higher confidence -> more weight on the frequency cue. Clamped so that
    # out-of-range scores cannot yield a negative (or >0.8) blend weight.
    alpha = float(np.clip(confidence_score * 1.2, 0.0, 0.8))
    combined = alpha * freq_map + (1.0 - alpha) * grad_map

    combined = gaussian_filter(combined, sigma=10)
    return _normalize(combined)
| def _convolve2d(img: np.ndarray, kernel: np.ndarray) -> np.ndarray: | |
| """Manual 2D convolution via stride tricks (no scipy dependency for this).""" | |
| from scipy.ndimage import convolve | |
| return convolve(img, kernel, mode="reflect") | |
| # --------------------------------------------------------------------------- | |
| # Colormap and overlay helpers | |
| # --------------------------------------------------------------------------- | |
| def _apply_overlay(img_arr: np.ndarray, heat: np.ndarray, alpha: float = 0.55) -> np.ndarray: | |
| """ | |
| Blend red/yellow heatmap over original image. | |
| Returns RGBA uint8 array (224, 224, 4). | |
| """ | |
| # Map heat [0,1] to RGBA: 0=transparent, 0.5=orange, 1.0=bright red | |
| r = np.ones_like(heat) # R channel: always full | |
| g = np.clip(1.0 - heat * 1.4, 0, 1) # G: fades out β red | |
| b = np.zeros_like(heat) # B: always 0 | |
| overlay_rgb = np.stack([r, g, b], axis=-1) # (224,224,3) float [0,1] | |
| overlay_alpha = np.clip(heat * alpha * 255, 0, 255) # (224,224) float | |
| # Blend: result = img * (1 - a) + color * a | |
| a3 = (overlay_alpha[:, :, np.newaxis] / 255.0) | |
| blended = (img_arr / 255.0) * (1.0 - a3) + overlay_rgb * a3 | |
| blended = np.clip(blended * 255, 0, 255).astype(np.uint8) | |
| # Add alpha channel | |
| alpha_ch = overlay_alpha.astype(np.uint8) | |
| # Keep full opacity everywhere, just use blend for color | |
| full_alpha = np.full((224, 224), 255, dtype=np.uint8) | |
| rgba = np.dstack([blended, full_alpha]) | |
| return rgba | |
def _encode_png(rgba_arr: np.ndarray) -> str:
    """
    Serialize an RGBA pixel array to an optimized PNG and wrap it in a data URI.

    Args:
        rgba_arr: uint8 array interpreted by PIL as an RGBA image.

    Returns:
        A string of the form "data:image/png;base64,<payload>".
    """
    with io.BytesIO() as stream:
        Image.fromarray(rgba_arr, mode="RGBA").save(stream, format="PNG", optimize=True)
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
| # --------------------------------------------------------------------------- | |
| # Utility helpers | |
| # --------------------------------------------------------------------------- | |
| def _normalize(arr: np.ndarray) -> np.ndarray: | |
| mn, mx = arr.min(), arr.max() | |
| if mx - mn < 1e-8: | |
| return np.zeros_like(arr, dtype=np.float32) | |
| return ((arr - mn) / (mx - mn)).astype(np.float32) | |
def _upsample(patch_map: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """
    Bilinearly resize a small 2D map to (target_h, target_w) using PIL.

    The resize runs on a 32-bit float ("F" mode) image. This avoids the 8-bit
    quantization of a uint8 round-trip, and values outside [0, 1] are not
    wrapped around by a ``uint8`` cast.

    Args:
        patch_map: 2D array (e.g. a normalized patch grid).
        target_h: Output height in pixels.
        target_w: Output width in pixels.

    Returns:
        (target_h, target_w) float32 array.
    """
    pil = Image.fromarray(np.ascontiguousarray(patch_map, dtype=np.float32))
    pil = pil.resize((target_w, target_h), Image.BILINEAR)
    return np.array(pil, dtype=np.float32)