ViT-Auditing-Toolkit / src /predictor.py
Dyuti Dasmahapatra
complete Phase 1 - core ViT auditing toolkit implementation
a01dc02
raw
history blame
2.81 kB
# src/predictor.py
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
def predict_image(image, model, processor, top_k=5):
    """
    Perform inference on an image and return top-k predictions.

    Args:
        image (PIL.Image): Input image to classify.
        model: Loaded ViT model. It must expose ``parameters()``, return an
            object with a ``logits`` attribute when called with the processor
            output, and carry a ``config.id2label`` mapping.
        processor: Loaded ViT processor, invoked as
            ``processor(images=image, return_tensors="pt")``.
        top_k (int): Number of top predictions to return. Values larger than
            the number of classes are clamped to the number of classes.

    Returns:
        tuple: (top_probs, top_indices, top_labels) - NumPy arrays of
        probabilities and class indices, plus the list of label names, all in
        the order produced by ``torch.topk``.

    Raises:
        Exception: Any error raised while preprocessing or running the model
            is reported on stdout and then re-raised unchanged.
    """
    try:
        # Run on whatever device the model's weights already live on.
        device = next(model.parameters()).device

        # Preprocess the image and move every returned tensor to that device.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Softmax over the class axis; [0] keeps the first item of the batch.
        probabilities = F.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes, so clamp it
        # (e.g. the default top_k=5 with a 3-class classification head).
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        # Hand plain NumPy arrays back to the caller.
        top_probs = top_probs.cpu().numpy()
        top_indices = top_indices.cpu().numpy()

        # Look labels up with built-in ints rather than NumPy scalars.
        top_labels = [model.config.id2label[int(idx)] for idx in top_indices]

        return top_probs, top_indices, top_labels
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise
def create_prediction_plot(probs, labels):
    """
    Create a clean, professional bar chart for top predictions.

    Args:
        probs (np.array): Array of probabilities, one per label.
        labels (list): List of label names.

    Returns:
        matplotlib.figure.Figure: The generated plot figure. The figure is
        created through pyplot and is not closed here.
    """
    fig, ax = plt.subplots(figsize=(8, 4))

    # Create horizontal bar chart, one bar per label.
    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, probs, color='skyblue', alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontsize=10)
    ax.set_xlabel('Confidence', fontsize=12)
    ax.set_title('Top Predictions', fontsize=14, fontweight='bold')

    # Empty input has no maximum, and an all-zero input would collapse the
    # axis to (0, 0); fall back to a unit-wide axis in both cases.
    peak = float(max(probs)) if len(probs) > 0 else 0.0
    x_max = peak * 1.15 if peak > 0 else 1.0  # 15% padding for the text

    # Keep the gap between bar and text proportional to the axis span so the
    # annotations stay inside the padding even when confidences are small.
    label_gap = 0.01 * x_max

    # Add probability text just past the end of each bar.
    for bar, prob in zip(bars, probs):
        ax.text(bar.get_width() + label_gap,
                bar.get_y() + bar.get_height() / 2,
                f'{prob:.2%}', va='center', fontsize=9)

    # Set x-axis limit and style.
    ax.set_xlim(0, x_max)
    ax.grid(axis='x', alpha=0.3, linestyle='--')

    fig.tight_layout()
    return fig