Spaces:
Sleeping
Sleeping
| # src/predictor.py | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def predict_image(image, model, processor, top_k=5):
    """
    Perform inference on an image and return top-k predictions.

    Args:
        image (PIL.Image): Input image to classify.
        model: Loaded ViT model. It is called as ``model(**inputs)``, its
            output must expose ``.logits``, and ``model.config.id2label``
            must map a class index to a label name.
        processor: Loaded ViT processor, called with
            ``images=image, return_tensors="pt"``.
        top_k (int): Number of top predictions to return. Values larger
            than the number of classes are clamped to the number of classes.

    Returns:
        tuple: (top_probs, top_indices, top_labels)
            - top_probs: numpy float32 array of probabilities, highest first.
            - top_indices: numpy integer array of the matching class indices.
            - top_labels: list of label names for those indices.

    Raises:
        Exception: any error from preprocessing or inference is printed and
        re-raised unchanged.
    """
    try:
        # Run the inputs on whatever device the model's weights live on.
        device = next(model.parameters()).device

        # Preprocess the image - note: current processors return pixel_values.
        inputs = processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Perform inference without building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)

        # Upcast to float32: softmax is more stable in full precision, and
        # NumPy has no bfloat16 type, so .numpy() below would otherwise fail
        # for half-precision models.
        logits = outputs.logits.float()

        # Class probabilities for the first image in the batch.
        probabilities = F.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(top_k, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        # Move results to host memory as numpy arrays.
        top_probs = top_probs.cpu().numpy()
        top_indices = top_indices.cpu().numpy()

        # Look up labels with plain Python ints rather than NumPy integers.
        top_labels = [model.config.id2label[int(idx)] for idx in top_indices]
        return top_probs, top_indices, top_labels
    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise
def create_prediction_plot(probs, labels):
    """
    Create a clean, professional bar chart for top predictions.

    Bars are drawn top-down in the order given, so the first entry appears
    at the top of the chart.

    Args:
        probs (array-like): Probabilities, one per label.
        labels (list): Label names, same length as ``probs``.

    Returns:
        matplotlib.figure.Figure: The generated plot figure. Empty input
        produces an empty chart instead of raising.
    """
    probs = np.asarray(probs, dtype=float)
    fig, ax = plt.subplots(figsize=(8, 4))

    # Create horizontal bar chart
    y_pos = np.arange(len(labels))
    bars = ax.barh(y_pos, probs, color='skyblue', alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(labels, fontsize=10)
    # barh places index 0 at the bottom; flip so the first entry is on top.
    ax.invert_yaxis()
    ax.set_xlabel('Confidence', fontsize=12)
    ax.set_title('Top Predictions', fontsize=14, fontweight='bold')

    # Add probability text just past the end of each bar
    for bar, prob in zip(bars, probs):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                f'{prob:.2%}', va='center', fontsize=9)

    # Pad the right edge so the text fits. Fall back to [0, 1] when there is
    # nothing to scale by: max() of an empty sequence raises, a zero upper
    # limit gives a degenerate axis, and a NaN limit is rejected.
    max_prob = probs.max() if probs.size else 0.0
    ax.set_xlim(0, max_prob * 1.15 if max_prob > 0 else 1.0)
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    fig.tight_layout()
    return fig