File size: 5,307 Bytes
56ec9ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""

Visualization utilities for anomaly detection results

"""

import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
from typing import Tuple, Optional
import config


def apply_heatmap(image: np.ndarray,
                  anomaly_map: np.ndarray,
                  alpha: float = 0.4,
                  colormap: str = "jet") -> np.ndarray:
    """
    Overlay an anomaly heatmap on the original image.

    The anomaly map is resized to the image resolution, min-max normalized
    to [0, 1] (so any input value range is accepted), colorized with
    ``colormap`` and alpha-blended onto the image.

    Args:
        image: Original image [H, W, 3] in RGB. Floating-point images whose
            maximum is <= 1.0 are treated as [0, 1] and scaled to [0, 255];
            all other images are clipped to [0, 255]. Grayscale ([H, W] or
            [H, W, 1]) and RGBA ([H, W, 4]) inputs are converted to RGB.
        anomaly_map: Anomaly map [H', W'].
        alpha: Weight of the heatmap in the blend; the image gets 1 - alpha.
        colormap: Matplotlib colormap name.

    Returns:
        Overlayed image [H, W, 3] in RGB, dtype uint8.
    """
    H, W = image.shape[:2]

    # Resize anomaly map to match image size. float32 is always accepted by
    # cv2.resize (float16 / bool maps would otherwise be rejected).
    anomaly_map_resized = cv2.resize(np.asarray(anomaly_map, dtype=np.float32), (W, H))

    # Min-max normalize to [0, 1]; the epsilon avoids 0/0 for a constant map.
    map_min = anomaly_map_resized.min()
    map_max = anomaly_map_resized.max()
    anomaly_map_norm = (anomaly_map_resized - map_min) / (map_max - map_min + 1e-8)

    # Apply colormap. plt.get_cmap is used because matplotlib.cm.get_cmap was
    # deprecated in 3.7 and removed in 3.9.
    cmap = plt.get_cmap(colormap)
    heatmap = cmap(anomaly_map_norm)[:, :, :3]  # Remove alpha channel
    heatmap = (heatmap * 255).astype(np.uint8)

    # Bring the image to 3 channels so it matches the heatmap in addWeighted.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Ensure image is uint8. Only float images are assumed to be in [0, 1]:
    # a dark uint8 image (max <= 1) must not be stretched by 255. Clipping
    # keeps out-of-range values from wrapping around in the uint8 cast.
    if np.issubdtype(image.dtype, np.floating) and image.max() <= 1.0:
        image = image * 255
    image = np.ascontiguousarray(np.clip(image, 0, 255).astype(np.uint8))

    # Blend image and heatmap
    overlayed = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)

    return overlayed


def create_result_visualization(original_image: Image.Image,
                                anomaly_score: float,
                                anomaly_map: np.ndarray,
                                threshold: float = 0.5,
                                ground_truth: Optional[np.ndarray] = None) -> Image.Image:
    """
    Create a side-by-side figure summarizing a single prediction.

    Panels, left to right: the original image; the heatmap overlay titled
    with the prediction and score; the raw anomaly map with a colorbar; and,
    when ``ground_truth`` is given, the ground-truth mask.

    Args:
        original_image: PIL Image.
        anomaly_score: Image-level anomaly score.
        anomaly_map: Pixel-level anomaly map.
        threshold: Decision threshold; scores strictly greater than it are
            labelled DEFECTIVE, otherwise NORMAL.
        ground_truth: Optional ground-truth mask, drawn in grayscale.

    Returns:
        The rendered figure as an RGB PIL Image.
    """
    # Convert to numpy array
    img_np = np.array(original_image)

    # Create figure with subplots
    n_cols = 4 if ground_truth is not None else 3
    fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 4, 4))

    # Close the figure even if drawing fails, so figures do not accumulate
    # in pyplot's global registry.
    try:
        # Original image
        axes[0].imshow(img_np)
        axes[0].set_title("Original Image")
        axes[0].axis('off')

        # Anomaly heatmap
        heatmap_overlay = apply_heatmap(img_np, anomaly_map, alpha=config.HEATMAP_ALPHA)
        axes[1].imshow(heatmap_overlay)

        is_defective = anomaly_score > threshold
        prediction = "DEFECTIVE" if is_defective else "NORMAL"
        color = "red" if is_defective else "green"
        axes[1].set_title(f"Prediction: {prediction}\nScore: {anomaly_score:.3f}",
                          color=color, fontweight='bold')
        axes[1].axis('off')

        # Raw anomaly map
        im = axes[2].imshow(anomaly_map, cmap=config.HEATMAP_COLORMAP)
        axes[2].set_title("Anomaly Map")
        axes[2].axis('off')
        fig.colorbar(im, ax=axes[2], fraction=0.046)

        # Ground truth (if available)
        if ground_truth is not None:
            axes[3].imshow(ground_truth, cmap='gray')
            axes[3].set_title("Ground Truth")
            axes[3].axis('off')

        fig.tight_layout()

        # Convert to PIL Image. buffer_rgba() replaces tostring_rgb(), which
        # was deprecated in matplotlib 3.8 and removed in 3.10. The buffer
        # already has the true pixel shape (correct on HiDPI canvases too).
        # The explicit copy detaches the pixels from the canvas memory
        # before the figure is closed.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        vis_pil = Image.fromarray(np.ascontiguousarray(rgba[:, :, :3]))
    finally:
        plt.close(fig)

    return vis_pil


def plot_roc_curve(fpr: np.ndarray, tpr: np.ndarray, auc: float,
                   save_path: Optional[str] = None):
    """
    Plot an ROC curve and optionally save it to disk.

    Args:
        fpr: False positive rates (x-axis values).
        tpr: True positive rates (y-axis values).
        auc: Area under the curve, shown in the legend label.
        save_path: If given, the figure is saved there (150 dpi, tight bbox)
            and a confirmation is printed. If omitted, the figure is built
            and closed without being shown or saved.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    # Always release the figure, even when savefig fails (e.g. bad path).
    try:
        ax.plot(fpr, tpr, linewidth=2, label=f'ROC (AUC = {auc:.3f})')
        ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title('ROC Curve - Image-Level Anomaly Detection', fontsize=14)
        ax.legend(fontsize=11)
        ax.grid(alpha=0.3)

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"ROC curve saved to {save_path}")
    finally:
        plt.close(fig)


def save_prediction(image: Image.Image,
                    anomaly_score: float,
                    anomaly_map: np.ndarray,
                    save_path: str,
                    threshold: float = 0.5):
    """
    Write the annotated heatmap overlay for one prediction to ``save_path``.

    The overlay is labelled near its top-left corner with the verdict and
    score: "DEFECTIVE" in red when ``anomaly_score`` is strictly greater than
    ``threshold``, otherwise "NORMAL" in green.

    Args:
        image: Source PIL Image.
        anomaly_score: Image-level anomaly score.
        anomaly_map: Pixel-level anomaly map passed to ``apply_heatmap``.
        save_path: Output file path; the format follows PIL's handling of it.
        threshold: Decision threshold for the verdict.
    """
    canvas = apply_heatmap(np.array(image), anomaly_map, alpha=config.HEATMAP_ALPHA)

    # Pick verdict text and RGB text color together (canvas is RGB, not BGR).
    if anomaly_score > threshold:
        verdict, text_rgb = "DEFECTIVE", (255, 0, 0)
    else:
        verdict, text_rgb = "NORMAL", (0, 255, 0)

    caption = f"{verdict} ({anomaly_score:.3f})"
    cv2.putText(canvas, caption, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                1, text_rgb, 2, cv2.LINE_AA)

    Image.fromarray(canvas).save(save_path)