File size: 4,150 Bytes
a01dc02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# src/utils.py

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch

def preprocess_image(image, target_size=224):
    """
    Prepare an image as square RGB input for the ViT model.

    Args:
        image: PIL Image, or a ``str`` file path that is opened with PIL
        target_size: Edge length in pixels of the square output

    Returns:
        PIL.Image: RGB image resized to (target_size, target_size)
    """
    # Only plain strings are treated as paths; anything else is used as-is.
    loaded = Image.open(image) if isinstance(image, str) else image

    # Skip the conversion when the image is already RGB.
    rgb = loaded if loaded.mode == 'RGB' else loaded.convert('RGB')

    # Aspect ratio is not preserved: both edges are forced to target_size.
    return rgb.resize((target_size, target_size))

def normalize_heatmap(heatmap):
    """
    Normalize heatmap to [0, 1] range using min-max scaling.

    Args:
        heatmap: numpy array (or array-like) of heatmap values

    Returns:
        numpy.array: Normalized heatmap with a floating-point dtype.
        Floating-point inputs keep their dtype; all other dtypes are
        promoted to float64. A constant heatmap (no value range) yields
        all zeros, and an empty heatmap yields an empty array.
    """
    values = np.asarray(heatmap)

    # Work in floating point: narrow integer dtypes (e.g. int8/uint8) can
    # overflow in `values - min` or `max - min`, and the constant-input
    # branch should return the same kind of dtype as the normal branch.
    if values.dtype.kind != 'f':
        values = values.astype(np.float64)

    # min()/max() raise on empty arrays; there is nothing to scale anyway.
    if values.size == 0:
        return values.copy()

    low = values.min()
    high = values.max()

    # `high > low` is also False when NaNs are present, in which case the
    # heatmap is treated like a constant one and zeros are returned.
    if high > low:
        return (values - low) / (high - low)
    return np.zeros_like(values)

def overlay_heatmap(image, heatmap, alpha=0.5, colormap='hot'):
    """
    Blend a colour-mapped heatmap onto an image.

    Args:
        image: PIL Image that the heatmap is drawn over
        heatmap: numpy array of heatmap values (indexed as a 2-D map)
        alpha: Blend weight passed to Image.blend (0 keeps only the
            image, 1 keeps only the heatmap)
        colormap: Matplotlib colormap name

    Returns:
        PIL.Image: RGB image with the heatmap blended in
    """
    # Scale values into [0, 1] so the colormap sees its expected range.
    scaled = normalize_heatmap(heatmap)

    # The colormap yields RGBA floats; keep RGB only and convert to bytes.
    rgba_values = plt.get_cmap(colormap)(scaled)
    rgb_bytes = (rgba_values[:, :, :3] * 255).astype(np.uint8)

    # Stretch the heatmap to the image's dimensions before blending.
    overlay = Image.fromarray(rgb_bytes).resize(image.size, Image.Resampling.LANCZOS)

    # Both layers are converted to RGBA so Image.blend sees matching modes.
    mixed = Image.blend(image.convert('RGBA'), overlay.convert('RGBA'), alpha)

    return mixed.convert('RGB')

def create_comparison_figure(original_image, explanation_images, explanation_titles):
    """
    Create a comparison figure showing original image and multiple explanations.

    The figure is a single row: the original image in the first panel,
    followed by one panel per explanation.

    Args:
        original_image: PIL Image
        explanation_images: List of explanation images (may be empty)
        explanation_titles: List of titles for each explanation; images and
            titles are paired in order, and pairing stops at the shorter list

    Returns:
        matplotlib.figure.Figure: Comparison figure
    """
    num_panels = len(explanation_images) + 1

    # squeeze=False always returns a 2-D array of Axes. Without it, a
    # single-panel figure (no explanations) returns a bare Axes object and
    # indexing it with axes[0] raises TypeError.
    fig, axes_grid = plt.subplots(
        1, num_panels, figsize=(4 * num_panels, 4), squeeze=False
    )
    axes = axes_grid[0]

    # Plot original image
    axes[0].imshow(original_image)
    axes[0].set_title('Original Image', fontweight='bold')
    axes[0].axis('off')

    # Plot explanations, starting at the panel after the original
    for ax, exp_img, title in zip(axes[1:], explanation_images, explanation_titles):
        ax.imshow(exp_img)
        ax.set_title(title, fontweight='bold')
        ax.axis('off')

    plt.tight_layout()
    return fig

def tensor_to_image(tensor):
    """
    Convert PyTorch tensor to PIL Image.

    Values outside [0, 1] are min-max rescaled into [0, 1] before being
    converted to 8-bit pixels; values already in [0, 1] are used as-is.

    Args:
        tensor: PyTorch tensor of shape (C, H, W) or (B, C, H, W). Only a
            leading batch dimension of size 1 is removed.

    Returns:
        PIL.Image: Converted image (single-channel tensors give a
        grayscale image)
    """
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)

    # Detach from autograd and move to host memory before numpy conversion.
    tensor = tensor.detach().cpu()

    low, high = tensor.min(), tensor.max()
    if low < 0 or high > 1:
        # Assume it's normalized, denormalize to [0, 1]
        if high > low:
            tensor = (tensor - low) / (high - low)
        else:
            # Constant tensor: there is no range to rescale, and dividing by
            # zero would produce NaN pixels. Use zeros, consistent with how
            # normalize_heatmap treats constant input.
            tensor = torch.zeros_like(tensor, dtype=torch.float32)

    # (C, H, W) -> (H, W, C), the layout PIL expects.
    numpy_image = tensor.permute(1, 2, 0).numpy()
    numpy_image = (numpy_image * 255).astype(np.uint8)

    # Image.fromarray rejects (H, W, 1) arrays; drop the channel axis so a
    # single-channel tensor becomes a 2-D grayscale image.
    if numpy_image.shape[2] == 1:
        numpy_image = numpy_image[:, :, 0]

    return Image.fromarray(numpy_image)

def get_top_predictions_dict(probs, labels, top_k=5):
    """
    Convert top predictions to dictionary for Gradio Label component.

    The label/probability pairs are ranked by probability (highest first)
    before the top_k entries are taken, so the inputs do not need to be
    pre-sorted. Input that is already sorted in descending order gives the
    same result as taking its first top_k entries.

    Args:
        probs: Array of probabilities, aligned index-by-index with labels
        labels: List of label names
        top_k: Number of top predictions to include

    Returns:
        dict: Dictionary of {label: probability} for top-k predictions,
        in descending order of probability
    """
    # Convert to plain floats up front so the sort key and the output values
    # do not depend on the element type (numpy scalar, tensor element, ...).
    pairs = [(label, float(prob)) for label, prob in zip(labels, probs)]

    # sorted() is stable, so ties keep their original relative order.
    ranked = sorted(pairs, key=lambda pair: pair[1], reverse=True)

    return dict(ranked[:top_k])