# Retrieved from a Hugging Face Space file viewer (Space status: Sleeping).
# File size: 4,150 bytes; revision a01dc02.
# src/utils.py
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
def preprocess_image(image, target_size=224):
    """
    Preprocess image for ViT model.

    Args:
        image: PIL Image, or a file path given as ``str`` or any
            ``os.PathLike`` object (e.g. ``pathlib.Path``)
        target_size: Edge length, in pixels, of the square output image

    Returns:
        PIL.Image: RGB image resized to ``(target_size, target_size)``
    """
    # Any path-like input (not only plain strings) is loaded from disk;
    # previously a pathlib.Path fell through and failed on ``.mode``.
    if isinstance(image, (str, os.PathLike)):
        image = Image.open(image)

    # Convert to RGB if necessary
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize to a square; the original aspect ratio is not preserved.
    return image.resize((target_size, target_size))
def normalize_heatmap(heatmap):
    """
    Normalize heatmap to [0, 1] range.

    Args:
        heatmap: numpy array (or array-like, e.g. nested lists) of heatmap values

    Returns:
        numpy.array: Floating-point array of the same shape, min-max scaled
            to [0, 1]. A constant or empty heatmap yields all zeros.
    """
    values = np.asarray(heatmap)

    # Work in floating point so integer input produces a float result on
    # every path (the constant-input path used to return integer zeros).
    # Existing float dtypes such as float32 are kept as they are.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    # min()/max() raise on empty arrays, so handle that case explicitly.
    if values.size == 0:
        return np.zeros_like(values)

    low = values.min()
    high = values.max()
    if high > low:
        return (values - low) / (high - low)

    # Flat heatmap: there is no range to scale by.
    return np.zeros_like(values)
def overlay_heatmap(image, heatmap, alpha=0.5, colormap='hot'):
    """
    Blend a colour-mapped heatmap on top of an image.

    Args:
        image: PIL Image that serves as the background
        heatmap: numpy array of heatmap values (indexed as a 2-D map)
        alpha: Blend weight of the heatmap layer passed to ``Image.blend``
        colormap: Matplotlib colormap name

    Returns:
        PIL.Image: RGB image with the heatmap blended over it
    """
    # Rescale first so the colormap lookup receives values in [0, 1].
    scaled = normalize_heatmap(heatmap)

    # Apply the colormap, keep the first three (RGB) channels and quantise
    # them to 8-bit values.
    coloured = plt.get_cmap(colormap)(scaled)
    rgb_bytes = (coloured[:, :, :3] * 255).astype(np.uint8)

    # Stretch the heatmap to the background image's dimensions.
    overlay = Image.fromarray(rgb_bytes).resize(image.size, Image.Resampling.LANCZOS)

    # Both layers are converted to RGBA before blending, then back to RGB.
    mixed = Image.blend(image.convert('RGBA'), overlay.convert('RGBA'), alpha)
    return mixed.convert('RGB')
def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a comparison figure showing original image and multiple explanations.

    Args:
        original_image: PIL Image
        explanation_images: List of explanation images (may be empty)
        explanation_titles: List of titles for each explanation; images and
            titles are paired positionally and pairing stops at the shorter list

    Returns:
        matplotlib.figure.Figure: Comparison figure with one row of panels,
            the original image first
    """
    num_panels = len(explanation_images) + 1

    # squeeze=False always returns a 2-D array of Axes. Without it a single
    # panel (no explanations) comes back as a bare Axes object and indexing
    # it fails.
    fig, axes_grid = plt.subplots(
        1, num_panels, figsize=(4 * num_panels, 4), squeeze=False
    )
    axes = axes_grid[0]

    # Plot original image
    axes[0].imshow(original_image)
    axes[0].set_title('Original Image', fontweight='bold')

    # Plot explanations
    for ax, exp_img, title in zip(axes[1:], explanation_images, explanation_titles):
        ax.imshow(exp_img)
        ax.set_title(title, fontweight='bold')

    # Hide ticks and frames on every panel, including any panel left empty
    # because fewer titles than images were supplied.
    for ax in axes:
        ax.axis('off')

    # Lay out this specific figure rather than whichever one is "current".
    fig.tight_layout()
    return fig
def tensor_to_image(tensor):
    """
    Convert PyTorch tensor to PIL Image.

    Args:
        tensor: PyTorch tensor of shape (C, H, W) or (1, C, H, W)

    Returns:
        PIL.Image: Converted image. Single-channel tensors become a 2-D
            (grayscale) image.

    Raises:
        ValueError: If a 4-D tensor holds anything other than exactly one image.
    """
    if tensor.dim() == 4:
        # squeeze(0) only removes a batch axis of length 1; fail early with
        # a clear message instead of an opaque permute() error further down.
        if tensor.size(0) != 1:
            raise ValueError(
                f"Expected a batch containing exactly one image, got batch size {tensor.size(0)}"
            )
        tensor = tensor.squeeze(0)

    tensor = tensor.cpu().detach()

    low = tensor.min()
    high = tensor.max()
    if low < 0 or high > 1:
        # Assume it's normalized, rescale to [0, 1]
        if high > low:
            tensor = (tensor - low) / (high - low)
        else:
            # Constant tensor: min-max scaling would divide by zero and
            # yield NaN pixels, so map it to zeros instead.
            tensor = torch.zeros_like(tensor)

    numpy_image = tensor.permute(1, 2, 0).numpy()
    numpy_image = (numpy_image * 255).astype(np.uint8)

    # An (H, W, 1) array is not accepted by Image.fromarray; drop the
    # channel axis so single-channel input becomes a grayscale image.
    if numpy_image.shape[2] == 1:
        numpy_image = numpy_image[:, :, 0]

    return Image.fromarray(numpy_image)
def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Convert top predictions to dictionary for Gradio Label component.

    Args:
        probs: Array of probabilities, aligned element-wise with ``labels``
        labels: List of label names
        top_k: Number of top predictions to include

    Returns:
        dict: Dictionary of {label: probability} for the ``top_k`` most
            probable labels, ordered from most to least probable
    """
    # Rank by probability instead of taking the first top_k entries, so an
    # unsorted class distribution still yields the real top predictions.
    # sorted() is stable, so input already in descending order (including
    # ties) comes out in the same order as before.
    ranked = sorted(
        zip(labels, probs),
        key=lambda pair: float(pair[1]),
        reverse=True,
    )
    return {label: float(prob) for label, prob in ranked[:top_k]}