File size: 4,764 Bytes
c5732cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Visualization module for anomaly detection results.

This module provides functions to create visual outputs including heatmaps,
overlays, and predicted mask visualizations.
"""

import cv2
import numpy as np

from config import HEATMAP_ALPHA, OVERLAY_ALPHA


def create_overlay(image: np.ndarray, heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """
    Create an overlay of the heatmap on the original image.

    The image is resized to 256x256 and alpha-blended with a JET-colored
    version of the heatmap, weighted by OVERLAY_ALPHA (image) and
    HEATMAP_ALPHA (heatmap) from config.

    Args:
        image: Original image in RGB format (H, W, 3). Its dtype must match
            the uint8 colored heatmap for ``cv2.addWeighted`` (presumably
            uint8 -- confirm at call sites).
        heatmap: Normalized heatmap in range [0, 1]. It must already be
            256x256 so it lines up with the resized image. Values outside
            [0, 1] are clipped before colorization.
        model_name: Name of the model for model-specific handling. For
            "efficientad", pixels where the heatmap is exactly 0 (padding)
            show the unblended original image.

    Returns:
        Overlay image in RGB format (256, 256, 3)
    """
    image_resized = cv2.resize(image, (256, 256))

    # Clip before scaling. Out-of-range floats (e.g. a slightly negative or
    # un-normalized map) would otherwise wrap around or give platform-dependent
    # values in the uint8 cast, instead of saturating at 0 / 255.
    heatmap_uint8 = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so it matches the RGB input image.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    blended = cv2.addWeighted(image_resized, OVERLAY_ALPHA, heatmap_color, HEATMAP_ALPHA, 0)

    if model_name == "efficientad":
        # Zero values are EfficientAD padding: show the original image there.
        # The raw (unclipped) heatmap is used so only exact zeros count as padding.
        padding = (heatmap == 0)[..., np.newaxis]
        return np.where(padding, image_resized, blended)

    return blended


def create_mask_visualization(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Create a visualization of the predicted mask overlaid on the image.

    The image is resized to 256x256. Anomalous pixels are tinted red at 30%
    opacity and outlined with white external contours. All other pixels keep
    their original values.

    Args:
        image: Original image in RGB format (H, W, 3)
        mask: Mask (256, 256) where values > 0 indicate anomaly. It may be
            bool, integer or float; it is not modified.

    Returns:
        Visualization image (256, 256, 3) with semi-transparent red mask and contours
    """
    image_resized = cv2.resize(image, (256, 256))
    vis_img = image_resized.copy()

    # findContours needs a single-channel uint8 image, so bool, int64 and float
    # masks would be rejected. Building a fresh array also keeps older OpenCV
    # versions (which modify their input) from altering the caller's mask.
    region = np.asarray(mask) > 0
    binary = region.astype(np.uint8)

    if region.any():
        color_mask = np.zeros_like(image_resized)
        color_mask[region] = [255, 0, 0]  # RGB red

        # Tint only the anomalous pixels. Blending the full frame would also
        # dim every normal pixel to 70% brightness, but only on images that
        # happen to contain an anomaly.
        blended = cv2.addWeighted(image_resized, 0.7, color_mask, 0.3, 0)
        vis_img[region] = blended[region]

        # [-2] selects the contour list under both OpenCV 3.x (three return
        # values) and 4.x (two return values).
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(vis_img, contours, -1, (255, 255, 255), 2)

    return vis_img


def draw_bounding_boxes(overlay: np.ndarray, mask_vis: np.ndarray, bboxes: list) -> None:
    """
    Draw bounding boxes on both overlay and mask visualization images.

    Args:
        overlay: RGB overlay image to draw on (modified in-place)
        mask_vis: RGB mask visualization image to draw on (modified in-place)
        bboxes: Iterable of bounding boxes [x1, y1, x2, y2] in pixel
            coordinates of the images. Coordinates may be ints, floats or
            NumPy scalars; ``None`` is treated as "no boxes".
    """
    if bboxes is None:
        return

    for box in bboxes:
        # cv2.rectangle rejects float (and, in some builds, NumPy integer)
        # coordinates, so normalize to plain Python ints.
        x1, y1, x2, y2 = (int(v) for v in box)
        # Both images are RGB (create_overlay / create_mask_visualization build
        # them from an RGB input), so the colors below are in RGB order.
        # Green boxes on overlay
        cv2.rectangle(overlay, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # Blue boxes on mask visualization. (255, 0, 0) would be red in RGB,
        # the same color as the mask tint.
        cv2.rectangle(mask_vis, (x1, y1), (x2, y2), (0, 0, 255), 2)


def create_heatmap_color(heatmap: np.ndarray, model_name: str) -> np.ndarray:
    """
    Create a colored heatmap image suitable for display.

    Args:
        heatmap: Normalized heatmap (H, W) in range [0, 1]. Values outside
            that range are clipped before colorization.
        model_name: Name of the model for model-specific handling. For
            "efficientad", pixels where the heatmap is exactly 0 (padding)
            are rendered black.

    Returns:
        Colored heatmap in RGB format (H, W, 3), dtype uint8
    """
    # Clip before the uint8 cast. Out-of-range floats would otherwise wrap
    # around (or be platform-dependent) instead of saturating at 0 / 255.
    heatmap_uint8 = (np.clip(heatmap, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert to RGB for display.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    # For EfficientAD, make padding (zero values) black. The raw (unclipped)
    # heatmap is used so only exact zeros count as padding.
    if model_name == "efficientad":
        heatmap_color[heatmap == 0] = [0, 0, 0]

    return heatmap_color


def create_visuals(
    image: np.ndarray,
    heatmap: np.ndarray,
    mask: np.ndarray,
    bboxes: list,
    model_name: str
) -> tuple:
    """
    Create all visualization outputs for a single inference result.

    The input image is downscaled to 256x256 once. That copy is shared by the
    overlay and mask visualizations, so each helper does not have to resize
    the full-resolution image again.

    Args:
        image: Original input image in RGB format
        heatmap: Normalized heatmap (H, W). It must be 256x256 to line up
            with the resized image in the overlay.
        mask: Binary mask (H, W). It must be 256x256 to index the resized
            image in the mask visualization.
        bboxes: List of bounding boxes [x1, y1, x2, y2], presumably in
            256x256 pixel coordinates (they are drawn on 256x256 images)
        model_name: Name of the model

    Returns:
        tuple: (original_resized, heatmap_color, overlay, mask_vis)
    """
    # Resize original image to 256x256 (done once for all outputs)
    original_resized = cv2.resize(image, (256, 256))

    heatmap_color = create_heatmap_color(heatmap, model_name)

    # The helpers resize their input to 256x256 again. For an input that is
    # already 256x256, cv2.resize just copies, so this gives identical pixels
    # while skipping two full-resolution resizes.
    overlay = create_overlay(original_resized, heatmap, model_name)
    mask_vis = create_mask_visualization(original_resized, mask)

    # Boxes are drawn in place. overlay and mask_vis are freshly allocated by
    # the helpers, so original_resized is not modified.
    draw_bounding_boxes(overlay, mask_vis, bboxes)

    return original_resized, heatmap_color, overlay, mask_vis