Dyuti Dasmahapatra
complete Phase 1 - core ViT auditing toolkit implementation
a01dc02
raw
history blame
4.15 kB
# src/utils.py
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
def preprocess_image(image, target_size=224):
    """
    Load (if given a path) and prepare an image for a ViT model.

    Args:
        image: PIL Image or path to an image file
        target_size: Edge length in pixels of the square output image

    Returns:
        PIL.Image: RGB image resized to (target_size, target_size)
    """
    # Accept a file path as a convenience.
    img = Image.open(image) if isinstance(image, str) else image
    # ViT models expect 3-channel RGB input.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # Square resize; aspect ratio is not preserved.
    return img.resize((target_size, target_size))
def normalize_heatmap(heatmap):
    """
    Rescale a heatmap into the [0, 1] range via min-max normalization.

    Args:
        heatmap: numpy array of heatmap values

    Returns:
        numpy.array: Normalized heatmap; all zeros when the input is constant
    """
    lo, hi = heatmap.min(), heatmap.max()
    if hi <= lo:
        # Constant input: normalizing would divide by zero, so return zeros.
        return np.zeros_like(heatmap)
    return (heatmap - lo) / (hi - lo)
def overlay_heatmap(image, heatmap, alpha=0.5, colormap='hot'):
    """
    Blend a colorized heatmap on top of an image.

    Args:
        image: PIL Image to overlay onto
        heatmap: numpy array of heatmap values (indexed as 2-D here)
        alpha: Blend weight of the heatmap (0 = image only, 1 = heatmap only)
        colormap: Matplotlib colormap name used to colorize the heatmap

    Returns:
        PIL.Image: RGB image with the heatmap blended in
    """
    # Colorize: push normalized values through the colormap, drop its alpha channel.
    scaled = normalize_heatmap(heatmap)
    colors = plt.get_cmap(colormap)(scaled)
    rgb = (colors[:, :, :3] * 255).astype(np.uint8)

    # Match the heatmap resolution to the target image.
    overlay = Image.fromarray(rgb).resize(image.size, Image.Resampling.LANCZOS)

    # Alpha-blend in RGBA space, then hand back a plain RGB image.
    blended = Image.blend(image.convert('RGBA'), overlay.convert('RGBA'), alpha)
    return blended.convert('RGB')
def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a side-by-side figure: the original image followed by explanations.

    Args:
        original_image: PIL Image shown in the first panel
        explanation_images: List of explanation images (may be empty)
        explanation_titles: List of titles, parallel to explanation_images

    Returns:
        matplotlib.figure.Figure: Figure with 1 + len(explanation_images) panels
    """
    num_explanations = len(explanation_images)
    # squeeze=False keeps `axes` a 2-D array even for a single subplot, so the
    # indexing below also works when explanation_images is empty (previously
    # plt.subplots(1, 1) returned a bare Axes and axes[0] raised TypeError).
    fig, axes = plt.subplots(
        1, num_explanations + 1,
        figsize=(4 * (num_explanations + 1), 4),
        squeeze=False,
    )
    axes = axes[0]

    # First panel: the unmodified input image.
    axes[0].imshow(original_image)
    axes[0].set_title('Original Image', fontweight='bold')
    axes[0].axis('off')

    # Remaining panels: one per explanation, with its title.
    for i, (exp_img, title) in enumerate(zip(explanation_images, explanation_titles)):
        axes[i + 1].imshow(exp_img)
        axes[i + 1].set_title(title, fontweight='bold')
        axes[i + 1].axis('off')

    plt.tight_layout()
    return fig
def tensor_to_image(tensor):
    """
    Convert a PyTorch tensor to a PIL Image.

    Args:
        tensor: PyTorch tensor of shape (C, H, W) or (B, C, H, W);
            a leading batch dimension of size 1 is stripped

    Returns:
        PIL.Image: Converted 8-bit image
    """
    # Drop a leading batch dimension if present.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    tensor = tensor.cpu().detach()

    # Values outside [0, 1] are assumed to come from a normalized model input;
    # rescale them back into [0, 1] via min-max.
    if tensor.min() < 0 or tensor.max() > 1:
        span = tensor.max() - tensor.min()
        if span > 0:
            tensor = (tensor - tensor.min()) / span
        else:
            # Constant tensor: the old 0/0 produced NaNs that became
            # garbage uint8 values below; map it to black instead.
            tensor = torch.zeros_like(tensor)

    # (C, H, W) -> (H, W, C), scale to 8-bit for PIL.
    numpy_image = tensor.permute(1, 2, 0).numpy()
    numpy_image = (numpy_image * 255).astype(np.uint8)
    return Image.fromarray(numpy_image)
def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Build a {label: probability} dict for the Gradio Label component.

    Only the first `top_k` pairs are kept; callers are expected to pass
    probabilities already sorted in descending order (not verified here).

    Args:
        probs: Array of probabilities
        labels: List of label names
        top_k: Number of top predictions to include

    Returns:
        dict: Dictionary of {label: probability} for top-k predictions
    """
    result = {}
    for rank, (label, prob) in enumerate(zip(labels, probs)):
        if rank >= top_k:
            break
        result[label] = float(prob)
    return result