""" Utility Functions Module This module provides helper functions for image preprocessing, heatmap manipulation, visualization, and data conversion used throughout the ViT auditing toolkit. Author: ViT-XAI-Dashboard Team License: MIT """ import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image def preprocess_image(image, target_size=224): """ Preprocess an image for Vision Transformer model input. This function handles loading images from file paths, converts them to RGB format, and resizes them to the target dimensions required by ViT models. Args: image (PIL.Image or str): Input image as a PIL Image object or file path string. target_size (int, optional): Target square size for resizing. Defaults to 224, which is the standard input size for most ViT models. Returns: PIL.Image: Preprocessed RGB image resized to (target_size, target_size). Example: >>> # From file path >>> img = preprocess_image("path/to/image.jpg") >>> # From PIL Image >>> from PIL import Image >>> img = Image.open("cat.jpg") >>> processed_img = preprocess_image(img, target_size=384) Note: - Grayscale and RGBA images are automatically converted to RGB - Maintains aspect ratio is not preserved; images are center-cropped and resized - No normalization is applied; use model processor for that """ # If input is a file path string, load the image if isinstance(image, str): image = Image.open(image) # Convert to RGB if necessary (handles grayscale, RGBA, etc.) if image.mode != "RGB": image = image.convert("RGB") # Resize image to target dimensions # Uses LANCZOS resampling for high-quality downsampling image = image.resize((target_size, target_size)) return image def normalize_heatmap(heatmap): """ Normalize a heatmap array to the [0, 1] range using min-max scaling. This function is essential for visualizing heatmaps with consistent color mapping, regardless of the original value range. It handles edge cases where all values are identical. Args: heatmap (np.ndarray): Input heatmap array of any shape. Can contain any numeric values (int or float). Returns: np.ndarray: Normalized heatmap with values in [0, 1] range, preserving the original shape and relative differences between values. Example: >>> heatmap = np.array([[100, 200], [150, 250]]) >>> normalized = normalize_heatmap(heatmap) >>> print(normalized) [[0.0, 0.666...], [0.333..., 1.0]] >>> # Edge case: all values are the same >>> constant = np.array([[5, 5], [5, 5]]) >>> normalized = normalize_heatmap(constant) >>> print(normalized) [[0. 0.] [0. 0.]] Note: - Uses min-max normalization: (x - min) / (max - min) - Returns zeros if max equals min (constant heatmap) - Preserves NaN and inf values in the output """ # Check if there's any variation in the heatmap if heatmap.max() > heatmap.min(): # Apply min-max normalization to scale to [0, 1] return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min()) else: # If all values are the same, return zeros return np.zeros_like(heatmap) def overlay_heatmap(image, heatmap, alpha=0.5, colormap="hot"): """ Overlay a normalized heatmap on an original image with transparency blending. This function creates a visualization by blending a heatmap (e.g., attention map, saliency map) with the original image. The heatmap is colored using a matplotlib colormap and blended with the image using alpha transparency. Args: image (PIL.Image): Original RGB image to overlay the heatmap on. heatmap (np.ndarray): 2D array of heatmap values. Will be automatically normalized to [0, 1] range and resized to match image dimensions. alpha (float, optional): Transparency level for heatmap overlay. Range: [0, 1] where 0 = invisible, 1 = fully opaque. Defaults to 0.5. colormap (str, optional): Matplotlib colormap name for heatmap coloring. Common options: 'hot', 'jet', 'viridis', 'coolwarm'. Defaults to 'hot'. Returns: PIL.Image: RGB image with heatmap overlay, same size as input image. Example: >>> from PIL import Image >>> import numpy as np >>> image = Image.open("cat.jpg") >>> heatmap = np.random.rand(14, 14) # Example attention map >>> overlay = overlay_heatmap(image, heatmap, alpha=0.6, colormap='jet') >>> overlay.save("cat_with_attention.jpg") Note: - Heatmap is automatically normalized to [0, 1] range - Heatmap is resized to match image dimensions using high-quality resampling - Supports any matplotlib colormap - Returns RGB image (alpha channel is removed after blending) """ # Normalize heatmap to [0, 1] range for consistent coloring heatmap = normalize_heatmap(heatmap) # Convert heatmap to RGB using the specified matplotlib colormap # plt.cm.get_cmap() returns a colormap function cmap = plt.get_cmap(colormap) # Apply colormap and extract RGB channels (discard alpha) heatmap_rgb = (cmap(heatmap)[:, :, :3] * 255).astype(np.uint8) # Convert numpy array to PIL Image for resizing heatmap_img = Image.fromarray(heatmap_rgb) # Resize heatmap to match original image dimensions # Uses LANCZOS for high-quality upsampling/downsampling heatmap_img = heatmap_img.resize(image.size, Image.Resampling.LANCZOS) # Convert both images to RGBA for blending original_rgba = image.convert("RGBA") heatmap_rgba = heatmap_img.convert("RGBA") # Blend images using alpha transparency # alpha parameter controls the weight of heatmap vs original image blended = Image.blend(original_rgba, heatmap_rgba, alpha) # Convert back to RGB (remove alpha channel) return blended.convert("RGB") def create_comparison_figure(original_image, explanation_images, explanation_titles): """ Create a side-by-side comparison figure showing original image and multiple explanations. This function is useful for comparing different explainability methods (e.g., attention, GradCAM, SHAP) in a single visualization. All images are displayed with equal sizing and no axis ticks for a clean presentation. Args: original_image (PIL.Image): The original input image to display first. explanation_images (list): List of PIL Images containing explanation visualizations. Each should be the same size as the original image. explanation_titles (list): List of strings with titles for each explanation. Length must match explanation_images. Returns: matplotlib.figure.Figure: Figure object with (1 + n) subplots arranged horizontally, where n = len(explanation_images). Example: >>> original = Image.open("cat.jpg") >>> attention_map = generate_attention_viz(original) >>> gradcam_map = generate_gradcam_viz(original) >>> >>> fig = create_comparison_figure( ... original, ... [attention_map, gradcam_map], ... ['Attention', 'GradCAM'] ... ) >>> fig.savefig('comparison.png') Note: - Automatically adjusts figure width based on number of images - All axes ticks are removed for cleaner visualization - Uses tight_layout() to prevent label overlap """ # Calculate number of explanation images num_explanations = len(explanation_images) # Create figure with horizontal subplot layout # Width scales with number of images (4 inches per image) fig, axes = plt.subplots( 1, num_explanations + 1, figsize=(4 * (num_explanations + 1), 4) # +1 for original image ) # Plot original image in first subplot axes[0].imshow(original_image) axes[0].set_title("Original Image", fontweight="bold") axes[0].axis("off") # Remove axis ticks and labels # Plot each explanation image in subsequent subplots for i, (exp_img, title) in enumerate(zip(explanation_images, explanation_titles)): axes[i + 1].imshow(exp_img) axes[i + 1].set_title(title, fontweight="bold") axes[i + 1].axis("off") # Remove axis ticks and labels # Adjust spacing to prevent title/label overlap plt.tight_layout() return fig def tensor_to_image(tensor): """ Convert a PyTorch tensor to a PIL Image. This utility function handles tensor-to-image conversion with automatic handling of batch dimensions, device placement (CPU/GPU), normalization, and channel ordering. Useful for visualizing model inputs, intermediate features, or generated images. Args: tensor (torch.Tensor): Input tensor of shape (C, H, W) or (B, C, H, W) where: - B = batch size (will be squeezed if present) - C = number of channels (typically 3 for RGB) - H = height in pixels - W = width in pixels Returns: PIL.Image: RGB image representation of the tensor. Example: >>> # Convert model input back to image >>> input_tensor = processor(image, return_tensors="pt")['pixel_values'] >>> recovered_image = tensor_to_image(input_tensor) >>> recovered_image.show() >>> # Visualize intermediate feature map >>> feature_map = model.get_intermediate_features(input_tensor) >>> feature_img = tensor_to_image(feature_map) Note: - Automatically removes batch dimension if present (4D -> 3D) - Moves tensor to CPU if on GPU - Detaches tensor from computation graph - Normalizes values to [0, 1] range if outside this range - Converts from (C, H, W) to (H, W, C) format for PIL - Scales to [0, 255] and converts to uint8 """ # Remove batch dimension if present # Changes shape from (1, C, H, W) to (C, H, W) if tensor.dim() == 4: tensor = tensor.squeeze(0) # Move tensor to CPU and detach from computation graph # This prevents gradient tracking and allows numpy conversion tensor = tensor.cpu().detach() # Normalize tensor to [0, 1] range if needed # Handles both normalized inputs (e.g., ImageNet normalization) # and unnormalized feature maps if tensor.min() < 0 or tensor.max() > 1: # Apply min-max normalization tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) # Convert from PyTorch's (C, H, W) to numpy's (H, W, C) format numpy_image = tensor.permute(1, 2, 0).numpy() # Scale to [0, 255] range and convert to unsigned 8-bit integers numpy_image = (numpy_image * 255).astype(np.uint8) # Convert numpy array to PIL Image return Image.fromarray(numpy_image) def get_top_predictions_dict(probs, labels, top_k=5): """ Convert top predictions to a dictionary format for Gradio Label component. This convenience function formats prediction results for display in Gradio's Label component, which requires a dictionary mapping class names to probabilities. Args: probs (np.ndarray or list): Array or list of probability scores. Should be in descending order (highest probability first). labels (list): List of class names corresponding to probabilities. Must have same length as probs or longer. top_k (int, optional): Number of top predictions to include. Defaults to 5. If larger than length of probs/labels, uses maximum available. Returns: dict: Dictionary mapping class names (str) to probability scores (float). Keys are class labels, values are probabilities in range [0, 1]. Example: >>> probs = np.array([0.87, 0.08, 0.03, 0.01, 0.01]) >>> labels = ['tabby cat', 'tiger cat', 'Egyptian cat', 'lynx', 'cougar'] >>> pred_dict = get_top_predictions_dict(probs, labels, top_k=3) >>> print(pred_dict) {'tabby cat': 0.87, 'tiger cat': 0.08, 'Egyptian cat': 0.03} >>> # Use with Gradio >>> import gradio as gr >>> output = gr.Label(label="Predictions") >>> # Can directly pass pred_dict to this component Note: - Probabilities are converted to Python float for JSON serialization - Only includes top_k predictions (useful for limiting display) - Maintains order from input (highest to lowest probability) """ # Create dictionary by zipping labels with probabilities # Slicing [:top_k] limits to top_k predictions # float() conversion ensures JSON serialization compatibility return {label: float(prob) for label, prob in zip(labels[:top_k], probs[:top_k])}