# iad-explainable-hf / core/visualization.py
# Parikshit Rathode
# initial commit
# c5732cc
"""
Visualization module for anomaly detection results.
This module provides functions to create visual outputs including heatmaps,
overlays, and predicted mask visualizations.
"""
import cv2
import numpy as np
from config import HEATMAP_ALPHA, OVERLAY_ALPHA
def create_overlay(image: np.ndarray, heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """
    Create an overlay of the heatmap on the original image.

    The image is resized to 256x256 and alpha-blended with a JET-coloured
    version of ``heatmap``. The image is weighted by ``OVERLAY_ALPHA`` and the
    heatmap by ``HEATMAP_ALPHA``, both from ``config``.

    Args:
        image: Original image in RGB format (H, W, 3). It must have the same
            dtype as the colour map output (uint8) for ``cv2.addWeighted``.
        heatmap: Normalized heatmap in range [0, 1]. It must already be
            256x256 so it matches the resized image. Out-of-range values are
            clipped and NaNs are treated as 0 before colouring.
        model_name: Name of the model for model-specific handling. For
            ``"efficientad"``, pixels where ``heatmap == 0`` (padding) show
            the unblended original image.

    Returns:
        Overlay image in RGB format (256, 256, 3).
    """
    image_resized = cv2.resize(image, (256, 256))

    # Clip before the uint8 cast: values even slightly outside [0, 1] (e.g.
    # normalization round-off) would otherwise wrap around (-0.01 -> 254) and
    # show up with the wrong colour.
    safe_heatmap = np.clip(np.nan_to_num(heatmap, nan=0.0), 0.0, 1.0)
    heatmap_uint8 = (safe_heatmap * 255).astype(np.uint8)
    # applyColorMap produces BGR; the rest of this module works in RGB.
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    blended = cv2.addWeighted(image_resized, OVERLAY_ALPHA, heatmap_color, HEATMAP_ALPHA, 0)

    if model_name == "efficientad":
        # Zero-valued heatmap pixels are treated as padding: keep the original
        # image there instead of the blended colour. The mask is taken from the
        # raw heatmap so only exact zeros count as padding.
        padding = (heatmap == 0)[..., np.newaxis]
        return np.where(padding, image_resized, blended)
    return blended
def create_mask_visualization(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Create a visualization of the predicted mask overlaid on the image.

    The image is resized to 256x256.
    - Pixels where ``mask`` is positive are tinted red (70% image, 30% pure red).
    - The external contour of each anomalous region is outlined in white.
    - Pixels outside the mask keep their original colour.

    Args:
        image: Original image in RGB format (H, W, 3).
        mask: Anomaly mask (H, W) where values > 0 indicate anomaly. It may be
            bool, integer or float. If its size differs from 256x256 it is
            resized with nearest-neighbour interpolation so it stays binary.

    Returns:
        Visualization image in RGB format (256, 256, 3) with a semi-transparent
        red mask and white contours.
    """
    image_resized = cv2.resize(image, (256, 256))
    vis_img = image_resized.copy()

    # cv2.findContours only accepts single-channel 8-bit input (bool/float
    # masks raise), and older OpenCV versions modify their input in place, so
    # work on a private uint8 copy of the caller's mask.
    mask_u8 = (np.asarray(mask) > 0).astype(np.uint8)
    height, width = image_resized.shape[:2]
    if mask_u8.shape[:2] != (height, width):
        mask_u8 = cv2.resize(mask_u8, (width, height), interpolation=cv2.INTER_NEAREST)

    if np.any(mask_u8):
        anomalous = mask_u8 > 0
        color_mask = np.zeros_like(image_resized)
        color_mask[anomalous] = [255, 0, 0]  # RGB Red
        blended = cv2.addWeighted(image_resized, 0.7, color_mask, 0.3, 0)
        # Only tint anomalous pixels. Blending the whole frame with a black
        # background would dim every normal pixel to 70%, but only when an
        # anomaly happens to be present.
        vis_img[anomalous] = blended[anomalous]

        # [-2] selects the contour list on OpenCV 2.x/4.x (2 return values)
        # as well as 3.x (3 return values).
        contours = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(vis_img, contours, -1, (255, 255, 255), 2)
    return vis_img
def draw_bounding_boxes(overlay: np.ndarray, mask_vis: np.ndarray, bboxes: list) -> None:
    """
    Draw bounding boxes on both overlay and mask visualization images.

    Both images are RGB, like everything else in this module. Boxes are drawn
    green on the overlay and blue on the mask visualization, with a line
    thickness of 2 px.

    Args:
        overlay: Overlay image to draw on (modified in-place).
        mask_vis: Mask visualization image to draw on (modified in-place).
        bboxes: Iterable of bounding boxes ``[x1, y1, x2, y2]`` in pixel
            coordinates of the images. Float or numpy values are rounded to
            int. ``None`` is treated as "no boxes".
    """
    if bboxes is None:
        return
    for box in bboxes:
        # cv2.rectangle rejects non-integer point coordinates.
        x1, y1, x2, y2 = (int(round(v)) for v in box)
        # Green boxes on overlay
        cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Blue boxes on mask visualization. The image is RGB, so blue is
        # (0, 0, 255); (255, 0, 0) would be red and vanish against the red mask tint.
        cv2.rectangle(mask_vis, (x1, y1), (x2, y2), (0, 0, 255), 2)
def create_heatmap_color(heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """
    Create a colored heatmap image suitable for display.

    Args:
        heatmap: Normalized heatmap (H, W) in range [0, 1]. Out-of-range
            values are clipped and NaNs are treated as 0 before colouring.
        model_name: Name of the model for model-specific handling. For
            ``"efficientad"``, pixels where ``heatmap == 0`` (padding) are
            painted black.

    Returns:
        Colored heatmap in RGB format (H, W, 3) using the JET colour map.
    """
    # Clip before the uint8 cast so values slightly outside [0, 1] do not wrap
    # around (e.g. -0.01 -> 254, which would render as hot instead of cold).
    safe_heatmap = np.clip(np.nan_to_num(heatmap, nan=0.0), 0.0, 1.0)
    heatmap_uint8 = (safe_heatmap * 255).astype(np.uint8)
    # applyColorMap produces BGR; convert to the RGB used throughout this module.
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    # For EfficientAD, make padding (exact zero values in the raw heatmap) black.
    if model_name == "efficientad":
        heatmap_color[heatmap == 0] = [0, 0, 0]
    return heatmap_color
def create_visuals(
    image: np.ndarray,
    heatmap: np.ndarray,
    mask: np.ndarray,
    bboxes: list,
    model_name: str
) -> tuple:
    """
    Build every display image for a single inference result.

    Args:
        image: Original input image in RGB format.
        heatmap: Normalized heatmap (H, W).
        mask: Binary mask (H, W).
        bboxes: List of bounding boxes ``[x1, y1, x2, y2]``.
        model_name: Name of the model; forwarded to the heatmap helpers for
            model-specific handling.

    Returns:
        tuple: ``(original_resized, heatmap_color, overlay, mask_vis)``.
        ``original_resized`` is ``image`` scaled to 256x256. ``overlay`` and
        ``mask_vis`` additionally carry the bounding boxes.
    """
    resized_input = cv2.resize(image, (256, 256))
    colored_heatmap = create_heatmap_color(heatmap, model_name)

    # The two composited views are produced fresh here, so drawing the boxes
    # onto them in place does not touch any caller-owned array.
    blended_view, mask_view = (
        create_overlay(image, heatmap, model_name),
        create_mask_visualization(image, mask),
    )
    draw_bounding_boxes(blended_view, mask_view, bboxes)

    return resized_input, colored_heatmap, blended_view, mask_view