Spaces:
Sleeping
Sleeping
File size: 12,989 Bytes
be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 a01dc02 be5c319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 |
"""
Utility Functions Module
This module provides helper functions for image preprocessing, heatmap manipulation,
visualization, and data conversion used throughout the ViT auditing toolkit.
Author: ViT-XAI-Dashboard Team
License: MIT
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
def preprocess_image(image, target_size=224):
    """
    Preprocess an image for Vision Transformer model input.

    Handles loading images from file paths, converts them to RGB format,
    and resizes them to the square dimensions required by ViT models.

    Args:
        image (PIL.Image or str): Input image as a PIL Image object or file path string.
        target_size (int, optional): Target square size for resizing. Defaults to 224,
            the standard input size for most ViT models.

    Returns:
        PIL.Image: Preprocessed RGB image of size (target_size, target_size).

    Example:
        >>> img = preprocess_image("path/to/image.jpg")
        >>> from PIL import Image
        >>> processed = preprocess_image(Image.open("cat.jpg"), target_size=384)

    Note:
        - Grayscale and RGBA images are automatically converted to RGB
        - Aspect ratio is NOT preserved; the image is stretched to a square
          (no cropping is performed)
        - No normalization is applied; use the model processor for that
    """
    # If input is a file path string, load the image from disk
    if isinstance(image, str):
        image = Image.open(image)
    # Convert to RGB if necessary (handles grayscale, RGBA, palette, etc.)
    if image.mode != "RGB":
        image = image.convert("RGB")
    # BUG FIX: the resample argument was previously omitted, so Pillow used
    # its default (bicubic) despite the comment promising LANCZOS. Pass it
    # explicitly for high-quality downsampling.
    image = image.resize((target_size, target_size), Image.Resampling.LANCZOS)
    return image
def normalize_heatmap(heatmap):
    """
    Min-max normalize a heatmap array into the [0, 1] range.

    Ensures consistent color mapping when visualizing heatmaps, regardless
    of the original value range. Degenerate (constant) heatmaps are mapped
    to all zeros instead of dividing by zero.

    Args:
        heatmap (np.ndarray): Input heatmap of any shape, int or float valued.

    Returns:
        np.ndarray: Array of the same shape with values scaled to [0, 1];
            all zeros when the input has no variation.

    Example:
        >>> normalize_heatmap(np.array([[100, 200], [150, 250]]))
        array([[0.        , 0.66666667],
               [0.33333333, 1.        ]])
    """
    # Compute the range once instead of calling min()/max() repeatedly.
    lo = heatmap.min()
    hi = heatmap.max()
    # Guard clause: a flat heatmap has no meaningful scaling — return zeros.
    if hi <= lo:
        return np.zeros_like(heatmap)
    # Standard min-max scaling: (x - min) / (max - min)
    return (heatmap - lo) / (hi - lo)
def overlay_heatmap(image, heatmap, alpha=0.5, colormap="hot"):
    """
    Blend a heatmap over an original image with alpha transparency.

    Produces a visualization by coloring the heatmap (e.g., an attention or
    saliency map) with a matplotlib colormap, resizing it to the image's
    dimensions, and alpha-blending the two.

    Args:
        image (PIL.Image): Original RGB image to overlay the heatmap on.
        heatmap (np.ndarray): 2D heatmap values; normalized to [0, 1] and
            resized to the image's dimensions automatically.
        alpha (float, optional): Overlay opacity in [0, 1]; 0 = invisible,
            1 = fully opaque. Defaults to 0.5.
        colormap (str, optional): Matplotlib colormap name ('hot', 'jet',
            'viridis', 'coolwarm', ...). Defaults to 'hot'.

    Returns:
        PIL.Image: RGB image with the heatmap overlay, same size as input.

    Example:
        >>> overlay = overlay_heatmap(Image.open("cat.jpg"),
        ...                           np.random.rand(14, 14),
        ...                           alpha=0.6, colormap='jet')
    """
    # Scale values into [0, 1] so the colormap is applied consistently.
    scaled = normalize_heatmap(heatmap)
    # plt.get_cmap returns a callable mapping [0, 1] -> RGBA floats;
    # keep only the RGB channels and convert to 8-bit pixel values.
    rgba_float = plt.get_cmap(colormap)(scaled)
    rgb_bytes = (rgba_float[:, :, :3] * 255).astype(np.uint8)
    # Turn the colored heatmap into a PIL image and match the target size
    # with high-quality LANCZOS resampling.
    overlay = Image.fromarray(rgb_bytes).resize(image.size, Image.Resampling.LANCZOS)
    # Blend in RGBA space; alpha weights the heatmap against the original.
    composite = Image.blend(image.convert("RGBA"), overlay.convert("RGBA"), alpha)
    # Drop the alpha channel for the final RGB output.
    return composite.convert("RGB")
def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a side-by-side comparison figure of the original image and explanations.

    Useful for comparing different explainability methods (e.g., attention,
    GradCAM, SHAP) in one visualization. All panels share equal sizing and
    have axis ticks removed for a clean presentation.

    Args:
        original_image (PIL.Image): The original input image, shown first.
        explanation_images (list): PIL Images of explanation visualizations;
            each should match the original image's size. May be empty, in
            which case only the original image is shown.
        explanation_titles (list): Title strings, one per explanation image.

    Returns:
        matplotlib.figure.Figure: Figure with (1 + n) horizontal subplots,
            where n = len(explanation_images).

    Example:
        >>> fig = create_comparison_figure(
        ...     original, [attention_map, gradcam_map], ['Attention', 'GradCAM']
        ... )
        >>> fig.savefig('comparison.png')

    Note:
        - Figure width scales with the number of images (4 inches per panel)
        - Uses tight_layout() to prevent title overlap
    """
    num_explanations = len(explanation_images)
    # BUG FIX: squeeze=False guarantees axes is always a 2D array, even when
    # explanation_images is empty (plt.subplots(1, 1) would otherwise return
    # a bare Axes object, making axes[0] fail).
    fig, axes = plt.subplots(
        1,
        num_explanations + 1,  # +1 panel for the original image
        figsize=(4 * (num_explanations + 1), 4),
        squeeze=False,
    )
    axes = axes[0]  # single row of Axes
    # Plot original image in the first subplot
    axes[0].imshow(original_image)
    axes[0].set_title("Original Image", fontweight="bold")
    axes[0].axis("off")  # Remove axis ticks and labels
    # Plot each explanation image in the subsequent subplots
    for i, (exp_img, title) in enumerate(zip(explanation_images, explanation_titles)):
        axes[i + 1].imshow(exp_img)
        axes[i + 1].set_title(title, fontweight="bold")
        axes[i + 1].axis("off")
    # Adjust spacing to prevent title/label overlap
    fig.tight_layout()
    return fig
def tensor_to_image(tensor):
    """
    Convert a PyTorch tensor to a PIL Image.

    Handles batch-dimension removal, device placement (CPU/GPU), detaching
    from the autograd graph, value normalization, and channel reordering.
    Useful for visualizing model inputs or generated images.

    Args:
        tensor (torch.Tensor): Tensor of shape (C, H, W) or (1, C, H, W):
            - C = number of channels (typically 3 for RGB)
            - H, W = spatial dimensions in pixels

    Returns:
        PIL.Image: RGB image representation of the tensor.

    Example:
        >>> input_tensor = processor(image, return_tensors="pt")['pixel_values']
        >>> tensor_to_image(input_tensor).show()

    Note:
        - Removes a leading batch dimension if present (4D -> 3D)
        - Moves the tensor to CPU and detaches from the computation graph
        - Min-max normalizes to [0, 1] only if values fall outside that range;
          a constant out-of-range tensor maps to all zeros (no NaN output)
        - Converts from (C, H, W) to (H, W, C), scales to [0, 255] uint8
    """
    # Remove batch dimension if present: (1, C, H, W) -> (C, H, W)
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    # Move to CPU and detach so the tensor can be converted to numpy
    # without tracking gradients.
    tensor = tensor.cpu().detach()
    # Normalize to [0, 1] only when values fall outside that range
    # (e.g., ImageNet-normalized inputs or raw feature maps).
    tmin = tensor.min()
    tmax = tensor.max()
    if tmin < 0 or tmax > 1:
        span = tmax - tmin
        if span > 0:
            tensor = (tensor - tmin) / span
        else:
            # BUG FIX: a constant out-of-range tensor previously caused a
            # 0/0 division producing a NaN-filled image; map it to zeros
            # instead (consistent with normalize_heatmap's behavior).
            tensor = torch.zeros_like(tensor)
    # Convert from PyTorch's (C, H, W) to numpy's (H, W, C) layout
    numpy_image = tensor.permute(1, 2, 0).numpy()
    # Scale to [0, 255] and convert to unsigned 8-bit integers
    numpy_image = (numpy_image * 255).astype(np.uint8)
    # Convert numpy array to PIL Image
    return Image.fromarray(numpy_image)
def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Format top predictions as a dict for Gradio's Label component.

    Gradio's Label component expects a mapping of class names to probability
    scores; this helper builds that mapping from parallel sequences.

    Args:
        probs (np.ndarray or list): Probability scores, highest first.
        labels (list): Class names aligned with probs; must be at least as
            long as the number of probabilities used.
        top_k (int, optional): Number of predictions to keep. Defaults to 5;
            capped implicitly by the shorter of probs/labels.

    Returns:
        dict: Mapping of class label (str) to probability (float in [0, 1]).

    Example:
        >>> get_top_predictions_dict(
        ...     np.array([0.87, 0.08, 0.03]),
        ...     ['tabby cat', 'tiger cat', 'Egyptian cat'],
        ...     top_k=3,
        ... )
        {'tabby cat': 0.87, 'tiger cat': 0.08, 'Egyptian cat': 0.03}

    Note:
        - Probabilities are cast to Python float for JSON serialization
        - Insertion order follows the input order (highest to lowest)
    """
    # Truncate both sequences to top_k, cast each score to a plain Python
    # float (numpy scalars are not JSON-serializable), and pair them up.
    return dict(zip(labels[:top_k], map(float, probs[:top_k])))
|