File size: 6,656 Bytes
be5c319
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
a01dc02
 
 
be5c319
a01dc02
 
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
 
be5c319
 
a01dc02
be5c319
 
 
a01dc02
be5c319
 
a01dc02
be5c319
 
a01dc02
 
be5c319
 
 
 
 
 
 
 
a01dc02
be5c319
 
a01dc02
 
be5c319
0101a8b
 
 
 
 
 
 
 
 
be5c319
a01dc02
be5c319
a01dc02
be5c319
a01dc02
 
be5c319
a01dc02
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
a01dc02
be5c319
a01dc02
be5c319
a01dc02
be5c319
 
 
a01dc02
 
be5c319
 
 
 
 
 
a01dc02
be5c319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a01dc02
be5c319
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Predictor Module

This module handles image classification predictions using Vision Transformer models.
It provides functions for making predictions and creating visualization plots of results.

Author: ViT-XAI-Dashboard Team
License: MIT
"""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def predict_image(image, model, processor, top_k=5):
    """
    Perform inference on an image and return top-k predicted classes with probabilities.

    The image is preprocessed with ``processor``, passed through ``model`` with
    gradients disabled, and the resulting logits are converted to probabilities
    with softmax. The ``top_k`` most probable classes are returned in descending
    order of probability.

    Args:
        image (PIL.Image): Input image to classify. Should be in RGB format.
        model (ViTForImageClassification): Pre-trained ViT model for inference.
        processor (ViTImageProcessor): Image processor for preprocessing.
        top_k (int, optional): Number of top predictions to return. Defaults to 5.
            If it exceeds the number of classes produced by the model, it is
            clamped to that number instead of raising.

    Returns:
        tuple: A tuple containing three elements:
            - top_probs (np.ndarray): Array of shape (k,) with confidence scores
            - top_indices (np.ndarray): Array of shape (k,) with class indices
            - top_labels (list): List of length k with human-readable class names
          where ``k = min(top_k, number_of_classes)``.

    Raises:
        Exception: Any error raised during preprocessing or inference is
            printed and then re-raised unchanged.

    Example:
        >>> from PIL import Image
        >>> image = Image.open("cat.jpg")
        >>> probs, indices, labels = predict_image(image, model, processor, top_k=3)
        >>> print(f"Top prediction: {labels[0]} with {probs[0]:.2%} confidence")
        Top prediction: tabby cat with 87.34% confidence

    Note:
        - Inference is performed with torch.no_grad() for efficiency
        - Inputs are moved to the device of the model's first parameter
        - Applies softmax to convert logits to probabilities
        - Classes missing from ``model.config.id2label`` are named ``class_<index>``
    """
    try:
        # Inputs must live on the same device as the model's weights.
        device = next(model.parameters()).device

        # Preprocess the image into a mapping of PyTorch tensors.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Forward pass without building an autograd graph.
        with torch.no_grad():
            logits = model(**inputs).logits

        # Softmax over the class dimension; [0] drops the batch dimension.
        probabilities = F.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the dimension size, so clamp it to the
        # number of classes the model actually produces.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        # Convert PyTorch tensors to NumPy arrays for easier handling.
        top_probs = top_probs.cpu().numpy()
        top_indices = top_indices.cpu().numpy()

        # Use the model's label mapping when it is available as a dict.
        id2label = getattr(getattr(model, "config", None), "id2label", None)
        if not isinstance(id2label, dict):
            id2label = {}

        top_labels = []
        for raw_idx in top_indices:
            idx = int(raw_idx)
            fallback = f"class_{idx}"
            # Try the integer key first, then the string key (mappings loaded
            # from raw JSON keep their keys as strings).
            top_labels.append(id2label.get(idx, id2label.get(str(idx), fallback)))

        return top_probs, top_indices, top_labels

    except Exception as e:
        print(f"❌ Error during prediction: {e}")
        raise


def create_prediction_plot(probs, labels):
    """
    Create a horizontal bar chart visualizing top predictions.

    Draws one horizontal bar per (probability, label) pair, annotates each bar
    with its value formatted as a percentage, and returns the figure.

    Args:
        probs (np.ndarray or list): Probability scores for each class.
            Should be in descending order (highest probability first).
        labels (list): Human-readable class names corresponding to probs.
            Length must match probs.

    Returns:
        matplotlib.figure.Figure: A matplotlib Figure object containing the bar chart.
            Can be displayed with fig.show() or saved with fig.savefig().

    Raises:
        ValueError: If probs and labels have different lengths.

    Example:
        >>> probs = np.array([0.87, 0.08, 0.03, 0.01, 0.01])
        >>> labels = ['tabby cat', 'tiger cat', 'Egyptian cat', 'lynx', 'cougar']
        >>> fig = create_prediction_plot(probs, labels)
        >>> fig.savefig('predictions.png')

    Note:
        - Uses horizontal bars for better label readability
        - Adds a percentage label just past the end of each bar
        - Includes subtle grid lines for easier value reading
        - X-axis extends 15% past the largest probability to leave room for
          the percentage labels; empty or all-zero input falls back to a 0-1 axis
    """
    # Validate before allocating a figure so a bad call does not leak one.
    if len(probs) != len(labels):
        raise ValueError(
            f"probs and labels must have the same length "
            f"(got {len(probs)} and {len(labels)})"
        )

    # Create figure and axis with specified size
    fig, ax = plt.subplots(figsize=(8, 4))

    # One bar per label; y_pos is the vertical position of each bar
    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, probs, color="skyblue", alpha=0.8)

    # Set y-axis ticks and labels
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontsize=10)

    # Set axis labels and title
    ax.set_xlabel("Confidence", fontsize=12)
    ax.set_title("Top Predictions", fontsize=14, fontweight="bold")

    # Upper x-limit: 15% past the largest bar. max() fails on empty input and a
    # zero peak would give a degenerate (0, 0) axis, so fall back to 1.0.
    peak = float(max(probs)) if len(probs) else 0.0
    x_max = peak * 1.15 if peak > 0 else 1.0

    # Offset the text relative to the axis range so labels stay inside the
    # axes even when every probability is small.
    label_offset = 0.01 * x_max

    # Add probability percentage text just past the end of each bar
    for bar, prob in zip(bars, probs):
        ax.text(
            bar.get_width() + label_offset,  # X position (slightly right of bar)
            bar.get_y() + bar.get_height() / 2,  # Y position (center of bar)
            f"{prob:.2%}",  # Format as percentage with 2 decimal places
            va="center",  # Vertical alignment
            fontsize=9,
        )

    ax.set_xlim(0, x_max)

    # Add subtle grid lines for easier value reading
    ax.grid(axis="x", alpha=0.3, linestyle="--")

    # Lay out this specific figure rather than pyplot's "current" figure
    fig.tight_layout()

    return fig