# Image-Forensics-Detect / explainability / spectral_heatmap.py
# Uploaded by dk2430098 ("Upload folder using huggingface_hub"), commit 928b74f.
"""
explainability/spectral_heatmap.py
------------------------------------
Spectral Heatmap Visualization for FFT/DCT Branch.
STATUS: COMPLETE
Converts the FFT log-magnitude spectrum from the spectral branch
into a colorized heatmap suitable for display in the web UI.
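Also generates an annotated copy of the spectrum showing:
- Reference frequency-band rings (LF / MF / HF) around the DC component
- The strongest peaks in the high-frequency band, marked with crosses
Output (keys of the dict returned by render_spectral_heatmap):
- spectrum_b64  : base64-encoded JPEG of the spectral heatmap
- annotated_b64 : base64-encoded JPEG with band rings and peaks annotated
The module also provides render_noise_map and render_edge_map, which
colorize the diffusion branch's noise map and the edge branch's edge map.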
"""
import numpy as np
import cv2
from utils.image_utils import encode_image_to_base64
def render_spectral_heatmap(
    spectrum_map: np.ndarray,
    img: np.ndarray,
    *,
    top_k: int = 5,
    peak_threshold: float = 0.5,
) -> dict:
    """
    Render the FFT spectrum as a colored heatmap and annotate anomaly peaks.

    Args:
        spectrum_map   : (H, W) float array, nominally in [0, 1] — from
                         spectral_branch. NaN is treated as 0, +inf as 1,
                         and all values are clipped to [0, 1] before use.
        img            : (H, W, 3) float32 in [0, 1] — original image.
                         Not used by this function; kept for API compatibility.
        top_k          : maximum number of high-frequency peaks to mark
                         (default 5). Values <= 0 mark no peaks.
        peak_threshold : a candidate peak is marked only if its (clipped)
                         spectrum value is strictly greater than this
                         (default 0.5).

    Returns:
        dict with:
            "spectrum_b64"  : base64 JPEG of raw spectrum heatmap
            "annotated_b64" : base64 JPEG with reference rings and peaks drawn
    """
    H, W = spectrum_map.shape

    # Sanitize before the uint8 cast: negative values, values > 1 and NaN
    # would otherwise wrap around / be undefined in uint8. This matches the
    # clipping performed by render_noise_map and render_edge_map.
    spectrum = np.clip(
        np.nan_to_num(spectrum_map, nan=0.0, posinf=1.0, neginf=0.0),
        0.0,
        1.0,
    )

    # ── 1. Colorize spectrum ───────────────────────────────────────
    spec_u8 = (spectrum * 255).astype(np.uint8)
    spec_colored = cv2.applyColorMap(spec_u8, cv2.COLORMAP_INFERNO)
    # applyColorMap produces BGR; convert so the encoded image is RGB.
    spec_colored_rgb = cv2.cvtColor(spec_colored, cv2.COLOR_BGR2RGB)
    spectrum_b64 = encode_image_to_base64(spec_colored_rgb)

    # ── 2. Annotate frequency bands and high-frequency peaks ──────
    annotated = spec_colored_rgb.copy()
    cy, cx = H // 2, W // 2

    # Ring radii are fractions of half the SHORTER side (not the diagonal),
    # i.e. the largest radius that still fits inside the spectrum image.
    half_side = min(H, W) // 2

    # Colors are RGB triples because `annotated` is in RGB order.
    radii = [
        ("LF", int(half_side * 0.10), (0, 200, 0)),    # Low freq (green)
        ("MF", int(half_side * 0.25), (255, 200, 0)),  # Mid freq (yellow)
        ("HF", int(half_side * 0.45), (255, 50, 50)),  # High freq (red)
    ]
    for label, r, color in radii:
        cv2.circle(annotated, (cx, cy), r, color, 1)
        cv2.putText(
            annotated, label,
            (cx + r + 3, cy),
            cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA,
        )

    # Mark DC component (center)
    cv2.circle(annotated, (cx, cy), 3, (255, 255, 255), -1)

    # Peak search band: excludes the DC / low-frequency core (< 15%) and
    # stops at the HF reference ring (45%).
    y_idx, x_idx = np.ogrid[:H, :W]
    dist = np.sqrt((y_idx - cy) ** 2 + (x_idx - cx) ** 2)
    hf_mask = (dist > half_side * 0.15) & (dist < half_side * 0.45)
    hf_spectrum = spectrum * hf_mask

    # Indices of the `top_k` largest values inside the band, descending.
    k = max(int(top_k), 0)  # guard: a negative slice bound would select ~all
    flat_idx = np.argsort(hf_spectrum.ravel())[::-1][:k]
    peak_rows, peak_cols = np.unravel_index(flat_idx, hf_spectrum.shape)

    for py, px in zip(peak_rows, peak_cols):
        if hf_spectrum[py, px] > peak_threshold:  # only significant peaks
            # Cast NumPy integers to built-in int: some OpenCV builds reject
            # np.int64 in point arguments.
            cv2.drawMarker(
                annotated, (int(px), int(py)), (0, 255, 255),
                cv2.MARKER_CROSS, 8, 1, cv2.LINE_AA,
            )

    annotated_b64 = encode_image_to_base64(annotated)

    return {
        "spectrum_b64": spectrum_b64,
        "annotated_b64": annotated_b64,
    }
def render_noise_map(noise_map: np.ndarray) -> str:
    """
    Colorize the diffusion branch's residual noise map for display.

    Args:
        noise_map : (H, W) float32, nominally in [0, 1]; values outside
                    that range are clipped before conversion to 8-bit.
    Returns:
        base64-encoded JPEG string of the VIRIDIS-colored map, in RGB order.
    """
    levels = (np.clip(noise_map, 0, 1) * 255).astype(np.uint8)
    # applyColorMap emits BGR; flip channels to RGB before encoding.
    rgb = cv2.cvtColor(
        cv2.applyColorMap(levels, cv2.COLORMAP_VIRIDIS),
        cv2.COLOR_BGR2RGB,
    )
    return encode_image_to_base64(rgb)
def render_edge_map(edge_map: np.ndarray) -> str:
    """
    Colorize the edge branch's edge-magnitude map for display.

    Args:
        edge_map : (H, W) float32, nominally in [0, 1]; values outside
                   that range are clipped before conversion to 8-bit.
    Returns:
        base64-encoded JPEG string of the BONE-colored map, in RGB order.
    """
    clipped = np.clip(edge_map, 0, 1)
    gray8 = (clipped * 255).astype(np.uint8)
    bgr = cv2.applyColorMap(gray8, cv2.COLORMAP_BONE)
    # OpenCV colormaps are BGR; the encoder is given RGB.
    return encode_image_to_base64(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))