# Deepguard-api / heatmap.py
# suyash-77's upload ("Upload 9 files"), commit a02f72f (verified)
"""
DeepGuard β€” Heatmap Generation Module
Strategy:
1. PRIMARY: Attention Rollout β€” extract multi-head attention matrices from
ONNX intermediate outputs and roll them up through all layers.
2. FALLBACK: Frequency Anomaly + Gradient Saliency β€” if attention weights
are not exported, compute a forensically meaningful heatmap
using DCT frequency analysis and Sobel edge gradients.
(Pure NumPy, <10ms, no additional inference passes.)
"""
import io
import numpy as np
from PIL import Image
from scipy.ndimage import gaussian_filter
import base64
from typing import Optional
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
def generate_heatmap(
    image: Image.Image,
    output_dict: dict,
    confidence_score: float,
) -> str:
    """
    Render the input image (resized to 224x224) with a red/yellow heatmap blended in.

    The heatmap comes from attention rollout when ``output_dict`` holds any
    entry whose name contains "attn" or "attention" (case-insensitive) and
    rollout succeeds; otherwise it falls back to the frequency/gradient
    saliency map, which is weighted by ``confidence_score``.

    Args:
        image: Original PIL image (any size, any mode).
        output_dict: Raw ONNX output dict {name: ndarray}.
        confidence_score: Model fake probability [0, 1].

    Returns:
        A PNG data URI string: "data:image/png;base64,...".
    """
    resized = image.convert("RGB").resize((224, 224), Image.BILINEAR)
    pixels = np.array(resized, dtype=np.float32)

    def _looks_like_attention(name: str) -> bool:
        lowered = name.lower()
        return "attn" in lowered or "attention" in lowered

    attention_names = list(filter(_looks_like_attention, output_dict))

    # Prefer model-internal evidence; rollout returns None when unusable.
    heat = _attention_rollout(output_dict, attention_names) if attention_names else None
    if heat is None:
        heat = _frequency_saliency(pixels, confidence_score)

    return _encode_png(_apply_overlay(pixels, heat))
# ---------------------------------------------------------------------------
# Strategy 1: Attention Rollout
# ---------------------------------------------------------------------------
def _attention_rollout(output_dict: dict, attn_keys: list) -> Optional[np.ndarray]:
"""
Roll up multi-head attention matrices across all transformer layers.
Returns a normalized (224, 224) float32 array or None on failure.
"""
try:
# Sort keys to ensure layer order (layer_0, layer_1, ...)
attn_keys_sorted = sorted(attn_keys)
rollout = None
for key in attn_keys_sorted:
attn = output_dict[key] # Expected shape: (1, heads, seq_len, seq_len)
if attn.ndim != 4:
continue
attn = attn.squeeze(0) # (heads, seq_len, seq_len)
attn = attn.mean(axis=0) # Average heads β†’ (seq_len, seq_len)
# Add residual identity (attention rollout formula)
identity = np.eye(attn.shape[0], dtype=np.float32)
attn = 0.5 * attn + 0.5 * identity
attn = attn / (attn.sum(axis=-1, keepdims=True) + 1e-8)
rollout = attn if rollout is None else np.matmul(rollout, attn)
if rollout is None:
return None
# Row 0 = CLS token β†’ attends to all patch tokens
cls_attn = rollout[0, 1:] # Drop CLS itself β†’ (num_patches,)
num_patches = cls_attn.shape[0]
side = int(np.sqrt(num_patches)) # 14 for ViT-base-patch16
if side * side != num_patches:
return None
patch_map = cls_attn.reshape(side, side)
patch_map = (patch_map - patch_map.min()) / (patch_map.max() - patch_map.min() + 1e-8)
# Upsample to 224Γ—224
heat = _upsample(patch_map, 224, 224)
heat = gaussian_filter(heat, sigma=8)
return _normalize(heat)
except Exception:
return None
# ---------------------------------------------------------------------------
# Strategy 2: Frequency Anomaly + Sobel Saliency (pure NumPy fallback)
# ---------------------------------------------------------------------------
def _frequency_saliency(img_arr: np.ndarray, confidence_score: float) -> np.ndarray:
    """
    Build a saliency map from two pixel-aligned signals.

    - Local high-frequency energy anomaly: a Gaussian high-pass residual
      isolates fine-grained detail/noise at each pixel; its smoothed energy is
      compared with the image-wide median, so regions whose high-frequency
      content is unusually suppressed or boosted stand out.
    - Sobel gradient magnitude: highlights edges, on the heuristic that
      generated images often show artifacts at object/background boundaries.

    Both maps are blended with a weight derived from ``confidence_score``,
    smoothed, and normalized.

    Args:
        img_arr: RGB array indexed as [row, col, channel], values in [0, 255].
        confidence_score: Model fake probability; a higher score shifts weight
            toward the frequency-anomaly map (weight clamped to [0, 0.8]).

    Returns:
        float32 map with the image's (H, W) shape, normalized to [0, 1].
    """
    gray = 0.299 * img_arr[:, :, 0] + 0.587 * img_arr[:, :, 1] + 0.114 * img_arr[:, :, 2]
    gray_norm = gray / 255.0

    # --- Local high-frequency energy (spatial domain) ---
    # Work per pixel rather than on a global FFT: a global spectrum is indexed
    # by frequency, not position, so it cannot be blended with grad_map or
    # overlaid on the image. Subtracting a sigma=1 blur keeps roughly the
    # content above ~0.19 cycles/pixel.
    residual = gray_norm - gaussian_filter(gray_norm, sigma=1)
    hf_energy = gaussian_filter(residual ** 2, sigma=8)
    # Median baseline is robust to a few very textured regions.
    freq_map = _normalize(np.abs(hf_energy - np.median(hf_energy)))

    # --- Sobel gradient magnitude ---
    ky = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)
    kx = ky.T
    gx = _convolve2d(gray_norm, kx)
    gy = _convolve2d(gray_norm, ky)
    grad_map = _normalize(np.sqrt(gx ** 2 + gy ** 2))

    # High confidence -> emphasize the frequency anomaly. Clamp both ends so
    # out-of-range scores cannot produce negative or >0.8 weights.
    alpha = float(np.clip(confidence_score * 1.2, 0.0, 0.8))
    combined = alpha * freq_map + (1.0 - alpha) * grad_map

    # Smooth into blobs suitable for an overlay, then rescale to [0, 1].
    combined = gaussian_filter(combined, sigma=10)
    return _normalize(combined)
def _convolve2d(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
"""Manual 2D convolution via stride tricks (no scipy dependency for this)."""
from scipy.ndimage import convolve
return convolve(img, kernel, mode="reflect")
# ---------------------------------------------------------------------------
# Colormap and overlay helpers
# ---------------------------------------------------------------------------
def _apply_overlay(img_arr: np.ndarray, heat: np.ndarray, alpha: float = 0.55) -> np.ndarray:
"""
Blend red/yellow heatmap over original image.
Returns RGBA uint8 array (224, 224, 4).
"""
# Map heat [0,1] to RGBA: 0=transparent, 0.5=orange, 1.0=bright red
r = np.ones_like(heat) # R channel: always full
g = np.clip(1.0 - heat * 1.4, 0, 1) # G: fades out β†’ red
b = np.zeros_like(heat) # B: always 0
overlay_rgb = np.stack([r, g, b], axis=-1) # (224,224,3) float [0,1]
overlay_alpha = np.clip(heat * alpha * 255, 0, 255) # (224,224) float
# Blend: result = img * (1 - a) + color * a
a3 = (overlay_alpha[:, :, np.newaxis] / 255.0)
blended = (img_arr / 255.0) * (1.0 - a3) + overlay_rgb * a3
blended = np.clip(blended * 255, 0, 255).astype(np.uint8)
# Add alpha channel
alpha_ch = overlay_alpha.astype(np.uint8)
# Keep full opacity everywhere, just use blend for color
full_alpha = np.full((224, 224), 255, dtype=np.uint8)
rgba = np.dstack([blended, full_alpha])
return rgba
def _encode_png(rgba_arr: np.ndarray) -> str:
    """
    Serialize an RGBA array as a base64 PNG ``data:`` URI.

    Args:
        rgba_arr: Array accepted by ``Image.fromarray`` in "RGBA" mode.

    Returns:
        "data:image/png;base64,<payload>" with an optimized PNG payload.
    """
    with io.BytesIO() as stream:
        Image.fromarray(rgba_arr, mode="RGBA").save(stream, format="PNG", optimize=True)
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
# ---------------------------------------------------------------------------
# Utility helpers
# ---------------------------------------------------------------------------
def _normalize(arr: np.ndarray) -> np.ndarray:
mn, mx = arr.min(), arr.max()
if mx - mn < 1e-8:
return np.zeros_like(arr, dtype=np.float32)
return ((arr - mn) / (mx - mn)).astype(np.float32)
def _upsample(patch_map: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """
    Bilinearly resize a small 2-D map to (target_h, target_w) using PIL.

    The map is resampled as a 32-bit float image (PIL mode "F"), so values keep
    full precision and are not limited to [0, 1]. The output is in the same
    units as the input.

    Args:
        patch_map: 2-D array of values (e.g. a normalized patch grid).
        target_h: Output height in pixels.
        target_w: Output width in pixels.

    Returns:
        float32 array of shape (target_h, target_w).
    """
    # Resample in float mode rather than via uint8, which would quantize to
    # 256 levels and wrap any value outside [0, 1]. A 2-D float32 array maps
    # to PIL mode "F" automatically.
    src = np.ascontiguousarray(patch_map, dtype=np.float32)
    resized = Image.fromarray(src).resize((target_w, target_h), Image.BILINEAR)
    return np.array(resized, dtype=np.float32)