Spaces:
Sleeping
Sleeping
| # src/utils.py | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import torch | |
def preprocess_image(image, target_size=224):
    """
    Prepare an input image for a ViT model.

    Args:
        image: PIL Image, or a path to an image file.
        target_size: Side length (pixels) of the square output image.

    Returns:
        PIL.Image: RGB image resized to (target_size, target_size).
    """
    # Accept a file path as well as an already-loaded image.
    if isinstance(image, str):
        image = Image.open(image)
    # The model expects 3-channel RGB input.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image.resize((target_size, target_size))
def normalize_heatmap(heatmap):
    """
    Rescale heatmap values into the [0, 1] range.

    Args:
        heatmap: numpy array of raw heatmap values.

    Returns:
        numpy.ndarray: Min-max normalized heatmap; all zeros when the
        input is constant (avoids division by zero).
    """
    lo, hi = heatmap.min(), heatmap.max()
    if hi <= lo:
        # Constant input: no meaningful scaling is possible.
        return np.zeros_like(heatmap)
    return (heatmap - lo) / (hi - lo)
def overlay_heatmap(image, heatmap, alpha=0.5, colormap='hot'):
    """
    Blend a colorized heatmap on top of an image.

    Args:
        image: PIL Image to overlay onto.
        heatmap: numpy array of raw heatmap values.
        alpha: Blend weight of the heatmap layer (0 = image only, 1 = heatmap only).
        colormap: Matplotlib colormap name used to colorize the heatmap.

    Returns:
        PIL.Image: RGB image with the heatmap blended in.
    """
    # Map normalized values through the colormap and drop the alpha channel.
    scaled = normalize_heatmap(heatmap)
    rgb = (plt.get_cmap(colormap)(scaled)[:, :, :3] * 255).astype(np.uint8)

    # Upsample the (typically coarse) heatmap to the image resolution.
    overlay = Image.fromarray(rgb).resize(image.size, Image.Resampling.LANCZOS)

    # Alpha-blend in RGBA space, then return a plain RGB result.
    blended = Image.blend(image.convert('RGBA'), overlay.convert('RGBA'), alpha)
    return blended.convert('RGB')
def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a comparison figure showing the original image and multiple explanations.

    Args:
        original_image: PIL Image.
        explanation_images: List of explanation images (may be empty).
        explanation_titles: List of titles, one per explanation image.

    Returns:
        matplotlib.figure.Figure: Figure with 1 + len(explanation_images) panels.
    """
    num_explanations = len(explanation_images)
    # squeeze=False keeps `axes` 2-D even for a single subplot, so the
    # indexing below also works when there are no explanation images
    # (plt.subplots(1, 1) would otherwise return a bare Axes object and
    # axes[0] would raise TypeError).
    fig, axes = plt.subplots(
        1,
        num_explanations + 1,
        figsize=(4 * (num_explanations + 1), 4),
        squeeze=False,
    )
    axes = axes[0]

    # Original image in the leftmost panel.
    axes[0].imshow(original_image)
    axes[0].set_title('Original Image', fontweight='bold')
    axes[0].axis('off')

    # One panel per explanation.
    for i, (exp_img, title) in enumerate(zip(explanation_images, explanation_titles)):
        axes[i + 1].imshow(exp_img)
        axes[i + 1].set_title(title, fontweight='bold')
        axes[i + 1].axis('off')

    plt.tight_layout()
    return fig
def tensor_to_image(tensor):
    """
    Convert a PyTorch tensor to a PIL Image.

    Args:
        tensor: PyTorch tensor of shape (C, H, W) or (B, C, H, W); a leading
            batch dimension of size 1 is squeezed away.

    Returns:
        PIL.Image: Converted 8-bit image.
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    tensor = tensor.cpu().detach()

    # Values outside [0, 1] indicate a normalized tensor: min-max rescale it.
    t_min, t_max = tensor.min(), tensor.max()
    if t_min < 0 or t_max > 1:
        span = t_max - t_min
        if span > 0:
            tensor = (tensor - t_min) / span
        else:
            # Constant tensor: the original (max - min) division would
            # produce NaNs; map it to black instead.
            tensor = torch.zeros_like(tensor)

    numpy_image = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    return Image.fromarray(numpy_image)
def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Build a {label: probability} dict of the top-k predictions for a
    Gradio Label component.

    Args:
        probs: Sequence/array of probabilities, aligned with `labels`.
        labels: List of label names.
        top_k: Number of top predictions to include.

    Returns:
        dict: {label: probability} for the `top_k` highest-probability
        entries, insertion-ordered from most to least probable.
    """
    # Rank by probability instead of trusting input order: the previous
    # implementation returned the *first* k entries, which was only the
    # top-k when the caller had already sorted the arrays.
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return {label: float(prob) for label, prob in ranked[:top_k]}