File size: 12,989 Bytes
be5c319
 
 
 
 
 
 
 
 
a01dc02
 
be5c319
a01dc02
be5c319
 
a01dc02
 
 
be5c319
 
 
 
 
a01dc02
be5c319
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
 
be5c319
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
 
be5c319
a01dc02
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
be5c319
a01dc02
 
be5c319
a01dc02
 
be5c319
 
a01dc02
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
be5c319
 
 
a01dc02
be5c319
a01dc02
be5c319
 
a01dc02
be5c319
 
 
a01dc02
be5c319
 
 
 
 
 
 
a01dc02
be5c319
 
 
 
a01dc02
 
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
be5c319
 
 
 
 
 
 
 
a01dc02
be5c319
 
 
 
a01dc02
 
be5c319
 
 
 
a01dc02
be5c319
a01dc02
 
be5c319
a01dc02
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
 
a01dc02
 
be5c319
 
 
a01dc02
be5c319
 
 
 
a01dc02
be5c319
a01dc02
be5c319
 
a01dc02
be5c319
 
a01dc02
be5c319
 
a01dc02
 
be5c319
a01dc02
 
be5c319
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""
Utility Functions Module

This module provides helper functions for image preprocessing, heatmap manipulation,
visualization, and data conversion used throughout the ViT auditing toolkit.

Author: ViT-XAI-Dashboard Team
License: MIT
"""

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def preprocess_image(image, target_size=224):
    """
    Preprocess an image for Vision Transformer model input.

    Loads the image when a file path is given, converts it to RGB, and resizes it
    to a square of ``target_size`` pixels using LANCZOS resampling.

    Args:
        image (PIL.Image, str or os.PathLike): Input image as a PIL Image object
            or a path to an image file.
        target_size (int, optional): Target square size for resizing. Defaults to 224,
            which is the standard input size for most ViT models.

    Returns:
        PIL.Image: Preprocessed RGB image resized to (target_size, target_size).

    Example:
        >>> # From file path
        >>> img = preprocess_image("path/to/image.jpg")

        >>> # From PIL Image
        >>> from PIL import Image
        >>> img = Image.open("cat.jpg")
        >>> processed_img = preprocess_image(img, target_size=384)

    Note:
        - Grayscale and RGBA images are automatically converted to RGB
        - Aspect ratio is NOT preserved: the image is stretched to a square,
          no cropping or padding is performed
        - No normalization is applied; use model processor for that
        - The input PIL image is never modified; a new image is returned
    """
    if isinstance(image, (str, os.PathLike)):
        # Image.open() is lazy and keeps the file handle open. convert() forces the
        # pixel data to be read and returns an independent copy, so the handle can
        # be released as soon as the with-block exits.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif image.mode != "RGB":
        # Handles grayscale, RGBA, palette images, etc.
        image = image.convert("RGB")

    # Pass the filter explicitly: without it Pillow falls back to its default
    # resampling filter rather than LANCZOS.
    return image.resize((target_size, target_size), Image.Resampling.LANCZOS)


def normalize_heatmap(heatmap):
    """
    Normalize a heatmap array to the [0, 1] range using min-max scaling.

    This function is essential for visualizing heatmaps with consistent color mapping,
    regardless of the original value range. The scale is derived from the finite
    values only, so a stray NaN or inf does not wipe out the rest of the map.

    Args:
        heatmap (np.ndarray or array-like): Input heatmap of any shape. Can contain
            any numeric values (int or float).

    Returns:
        np.ndarray: New floating-point array with the same shape as the input.
            Finite values are scaled to [0, 1]; NaN and +/-inf entries are passed
            through unchanged. Floating-point inputs keep their dtype, every other
            dtype is promoted to float64.

    Example:
        >>> heatmap = np.array([[100, 200], [150, 250]])
        >>> normalized = normalize_heatmap(heatmap)
        >>> print(normalized)
        [[0.         0.66666667]
         [0.33333333 1.        ]]

        >>> # Edge case: all values are the same
        >>> constant = np.array([[5, 5], [5, 5]])
        >>> normalized = normalize_heatmap(constant)
        >>> print(normalized)
        [[0. 0.]
         [0. 0.]]

    Note:
        - Uses min-max normalization: (x - min) / (max - min)
        - Returns zeros if max equals min (constant heatmap)
        - Preserves NaN and inf values in the output
        - Empty input yields an empty array instead of raising
        - The input array is never modified
    """
    values = np.asarray(heatmap)
    # Always work on a floating-point copy so both branches return the same dtype
    # (integer input used to yield integer zeros in the constant case).
    out_dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64
    normalized = values.astype(out_dtype, copy=True)

    finite_mask = np.isfinite(normalized)
    finite_values = normalized[finite_mask]
    if finite_values.size == 0:
        # Empty input, or nothing but NaN/inf: there is nothing to rescale.
        return normalized

    # Compute the range once, ignoring NaN/inf so they cannot poison the comparison.
    low = finite_values.min()
    high = finite_values.max()
    if high > low:
        normalized[finite_mask] = (finite_values - low) / (high - low)
    else:
        # Constant heatmap: avoid dividing by zero and map everything to 0.
        normalized[finite_mask] = 0
    return normalized


def overlay_heatmap(image, heatmap, alpha=0.5, colormap="hot"):
    """
    Blend a color-mapped heatmap over an image.

    The heatmap (e.g., an attention or saliency map) is rescaled with
    ``normalize_heatmap``, colored with a matplotlib colormap, resized to the
    image dimensions, and alpha-blended with the image.

    Args:
        image (PIL.Image): Image to draw the heatmap on.
        heatmap (np.ndarray): 2D array of heatmap values. It is normalized and
            resized to the image size inside this function.
        alpha (float, optional): Blend weight of the heatmap layer; 0 shows only
            the image, 1 shows only the heatmap. Defaults to 0.5.
        colormap (str, optional): Name of the matplotlib colormap used to color
            the heatmap, e.g. 'hot', 'jet', 'viridis', 'coolwarm'. Defaults to 'hot'.

    Returns:
        PIL.Image: RGB image of the same size as ``image`` with the heatmap blended in.

    Example:
        >>> from PIL import Image
        >>> import numpy as np
        >>> image = Image.open("cat.jpg")
        >>> heatmap = np.random.rand(14, 14)  # Example attention map
        >>> overlay = overlay_heatmap(image, heatmap, alpha=0.6, colormap='jet')
        >>> overlay.save("cat_with_attention.jpg")

    Note:
        - The heatmap is resized with LANCZOS resampling
        - Blending is done in RGBA; the alpha channel is dropped from the result
    """
    scaled = normalize_heatmap(heatmap)

    # Looking the values up in the colormap gives RGBA floats; keep the three
    # color channels and quantize them to bytes.
    rgba_floats = plt.get_cmap(colormap)(scaled)
    color_bytes = (rgba_floats[:, :, :3] * 255).astype(np.uint8)

    # Bring the colored heatmap to the resolution of the target image.
    heat_layer = Image.fromarray(color_bytes).resize(image.size, Image.Resampling.LANCZOS)

    # Image.blend needs both layers in the same mode, so blend in RGBA.
    mixed = Image.blend(image.convert("RGBA"), heat_layer.convert("RGBA"), alpha)
    return mixed.convert("RGB")


def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a side-by-side comparison figure showing original image and multiple explanations.

    This function is useful for comparing different explainability methods (e.g., attention,
    GradCAM, SHAP) in a single visualization. All images are displayed with equal sizing
    and no axis ticks for a clean presentation.

    Args:
        original_image (PIL.Image): The original input image to display first.
        explanation_images (list): List of PIL Images containing explanation visualizations.
            May be empty, in which case only the original image is shown.
        explanation_titles (list): List of strings with titles for each explanation.
            Length should match explanation_images; if the lengths differ, only the
            leading pairs are drawn and the remaining panels are left blank.

    Returns:
        matplotlib.figure.Figure: Figure object with (1 + n) subplots arranged horizontally,
            where n = len(explanation_images).

    Example:
        >>> original = Image.open("cat.jpg")
        >>> attention_map = generate_attention_viz(original)
        >>> gradcam_map = generate_gradcam_viz(original)
        >>>
        >>> fig = create_comparison_figure(
        ...     original,
        ...     [attention_map, gradcam_map],
        ...     ['Attention', 'GradCAM']
        ... )
        >>> fig.savefig('comparison.png')

    Note:
        - Automatically adjusts figure width based on number of images (4 inches each)
        - All axes ticks are removed for cleaner visualization
        - Uses tight_layout() on the returned figure to prevent label overlap
    """
    # One panel for the original image plus one per explanation
    num_panels = len(explanation_images) + 1

    # squeeze=False always yields a 2D array of axes. Without it a single panel
    # comes back as a bare Axes object and indexing it fails when there are no
    # explanation images.
    fig, axes = plt.subplots(1, num_panels, figsize=(4 * num_panels, 4), squeeze=False)
    panels = axes[0]

    # Plot original image in the first panel
    panels[0].imshow(original_image)
    panels[0].set_title("Original Image", fontweight="bold")

    # Plot each explanation image in the subsequent panels
    for ax, exp_img, title in zip(panels[1:], explanation_images, explanation_titles):
        ax.imshow(exp_img)
        ax.set_title(title, fontweight="bold")

    # Remove ticks and frames everywhere, including panels left unused because
    # the two input lists had different lengths.
    for ax in panels:
        ax.axis("off")

    # Lay out this figure explicitly instead of whatever pyplot considers current
    fig.tight_layout()

    return fig


def tensor_to_image(tensor):
    """
    Convert a PyTorch tensor to a PIL Image.

    This utility function handles tensor-to-image conversion with automatic handling
    of batch dimensions, device placement (CPU/GPU), normalization, and channel ordering.
    Useful for visualizing model inputs, intermediate features, or generated images.

    Args:
        tensor (torch.Tensor): Input tensor of shape (H, W), (C, H, W) or (B, C, H, W) where:
            - B = batch size (only the first sample of the batch is converted)
            - C = number of channels (typically 3 for RGB, 1 for grayscale)
            - H = height in pixels
            - W = width in pixels

    Returns:
        PIL.Image: Image representation of the tensor. Three-channel tensors give an
            RGB image; single-channel and 2D tensors give a grayscale (mode "L") image.

    Example:
        >>> # Convert model input back to image
        >>> input_tensor = processor(image, return_tensors="pt")['pixel_values']
        >>> recovered_image = tensor_to_image(input_tensor)
        >>> recovered_image.show()

        >>> # Visualize intermediate feature map
        >>> feature_map = model.get_intermediate_features(input_tensor)
        >>> feature_img = tensor_to_image(feature_map)

    Note:
        - Takes the first sample if a batch dimension is present (4D -> 3D)
        - Detaches tensor from computation graph and moves it to CPU
        - Min-max normalizes values to [0, 1] if any value lies outside this range;
          a constant out-of-range tensor becomes all zeros (black image)
        - Converts from (C, H, W) to (H, W, C) format for PIL
        - Scales to [0, 255] and converts to uint8
    """
    if tensor.dim() == 4:
        # Indexing (rather than squeeze(0)) also works when the batch holds more
        # than one sample; squeeze would leave the tensor 4D and break permute.
        tensor = tensor[0]
    elif tensor.dim() == 2:
        # Treat a bare (H, W) map as a single-channel image
        tensor = tensor.unsqueeze(0)

    # Detach from the computation graph and move to CPU for numpy conversion
    tensor = tensor.detach().cpu()

    # Half, bfloat16 and integer tensors are promoted so the arithmetic and the
    # numpy conversion below behave uniformly.
    if tensor.dtype not in (torch.float32, torch.float64):
        tensor = tensor.to(torch.float32)

    # Normalize tensor to [0, 1] range if needed. Handles both normalized inputs
    # (e.g., ImageNet normalization) and unnormalized feature maps.
    low, high = tensor.min(), tensor.max()
    if low < 0 or high > 1:
        span = high - low
        if span > 0:
            tensor = (tensor - low) / span
        else:
            # Constant tensor: dividing by a zero span would produce NaN pixels
            tensor = torch.zeros_like(tensor)

    # Convert from PyTorch's (C, H, W) to numpy's (H, W, C) format
    numpy_image = tensor.permute(1, 2, 0).numpy()

    # Scale to [0, 255] range and convert to unsigned 8-bit integers
    numpy_image = (numpy_image * 255).astype(np.uint8)

    # PIL cannot build an image from an (H, W, 1) array; drop the channel axis
    # so single-channel data becomes a grayscale image.
    if numpy_image.shape[2] == 1:
        numpy_image = numpy_image[:, :, 0]

    return Image.fromarray(numpy_image)


def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Build the ``{label: probability}`` mapping expected by Gradio's Label component.

    The first ``top_k`` labels are paired with the first ``top_k`` probabilities,
    in the order given. The function does not sort its inputs, so callers should
    pass them ordered from most to least likely.

    Args:
        probs (np.ndarray or list): Probability scores, highest first.
        labels (list): Class names aligned with ``probs``.
        top_k (int, optional): Maximum number of predictions to include.
            Defaults to 5. If fewer entries are available, all of them are used.

    Returns:
        dict: Mapping of class name (str) to probability (float), in input order.

    Example:
        >>> probs = np.array([0.87, 0.08, 0.03, 0.01, 0.01])
        >>> labels = ['tabby cat', 'tiger cat', 'Egyptian cat', 'lynx', 'cougar']
        >>> pred_dict = get_top_predictions_dict(probs, labels, top_k=3)
        >>> print(pred_dict)
        {'tabby cat': 0.87, 'tiger cat': 0.08, 'Egyptian cat': 0.03}

        >>> # Use with Gradio
        >>> import gradio as gr
        >>> output = gr.Label(label="Predictions")
        >>> # Can directly pass pred_dict to this component

    Note:
        - Every probability is converted to a built-in float (JSON friendly)
        - Pairing stops at the shorter of the two truncated sequences
    """
    top_labels = labels[:top_k]
    top_probs = probs[:top_k]
    # map(float, ...) turns numpy scalars into plain Python floats lazily, one
    # per label consumed by zip.
    return dict(zip(top_labels, map(float, top_probs)))