# Source: twoEuroClassifier / inference.py
# Uploaded by OlejnikM - commit "Upload inference.py" (c179902, verified)
import torch
from transformers import SwinForImageClassification, AutoImageProcessor
from PIL import Image
import joblib
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from pathlib import Path
class CoinPredictor:
    """Predict coin labels for an image with a fine-tuned Swin classifier.

    The predictor loads a ``SwinForImageClassification`` checkpoint and a
    joblib-serialized label encoder from ``model_dir``, crops a fixed margin
    from every input image, and returns the ``top_n`` most probable labels.
    """

    # Extensions tried, in order, when looking up a label's reference image.
    _REFERENCE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Panels shown before the predictions: original query + processed query.
    _QUERY_PANELS = 2

    def __init__(self, model_dir='model_checkpoints', top_n=10):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model
                (``best_model``) and label encoder (``label_encoder.joblib``)
            top_n (int): Number of top predictions to return
        """
        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = 0.15  # 15% crop from each side

        # The image processor comes from the base Swin checkpoint, not from
        # model_dir; only the classifier weights are loaded from disk.
        self.image_processor = AutoImageProcessor.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224"
        )

        # Load the trained model and switch it to inference mode.
        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # NOTE: joblib.load unpickles arbitrary objects, so the encoder file
        # must come from a trusted source.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    @staticmethod
    def _grid_rows(n_panels, n_cols):
        """Return the rows needed to fit ``n_panels`` into ``n_cols`` columns.

        Always at least 1, so an (empty) figure grid stays valid.
        """
        # Negated floor division is integer ceiling division.
        return max(1, -(-n_panels // n_cols))

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Removes ``self.crop_percentage`` of the height from the top and
        bottom and the same fraction of the width from the left and right.

        Args:
            image: Array whose first two dimensions are height and width

        Returns:
            The cropped view of ``image``
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)
        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Preprocess a single image for prediction.

        Args:
            image_path (str): Path to the image file

        Returns:
            PIL.Image.Image: Center-cropped RGB image

        Raises:
            ValueError: If the image cannot be read
        """
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        image = self.crop_center(image)
        # OpenCV loads images as BGR; convert to RGB before building the PIL image.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image)

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str): Path to the image file

        Returns:
            tuple: ``(predictions, image)`` where ``predictions`` is a list of
            ``(label, probability)`` tuples ordered from most to least
            probable, and ``image`` is the preprocessed PIL image. At most
            ``top_n`` predictions are returned (fewer if the model has fewer
            classes).

        Raises:
            ValueError: If the image cannot be read
        """
        image = self.preprocess_image(image_path)

        # Prepare image for model
        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        class_probs = probabilities[0]
        # torch.topk raises when k exceeds the number of classes, so clamp it.
        k = min(self.top_n, class_probs.shape[-1])
        top_probs, top_indices = torch.topk(class_probs, k)

        # Decode all class indices back to labels in one call.
        labels = self.label_encoder.inverse_transform(top_indices.cpu().numpy())
        predictions = [
            (label, float(prob))
            for label, prob in zip(labels, top_probs.cpu().numpy())
        ]
        return predictions, image

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        Reference images are looked up as ``<reference_dir>/<label><ext>`` for
        each supported extension; predictions without a readable reference
        image leave their grid cell empty.

        Args:
            image_path (str): Path to the query image
            predictions (tuple): (predictions, preprocessed_image) as returned
                by ``predict``
            reference_dir (str): Directory containing reference images

        Raises:
            ValueError: If the query image cannot be read
        """
        predictions, processed_image = predictions

        # Size the grid from the panels actually drawn (two query panels plus
        # one per prediction) so every subplot index fits in the grid.
        n_cols = 4  # 4 images per row
        n_rows = self._grid_rows(len(predictions) + self._QUERY_PANELS, n_cols)

        plt.figure(figsize=(15, 4 * n_rows))

        # Plot original query image
        original_img = cv2.imread(os.fspath(image_path))
        if original_img is None:
            raise ValueError(f"Could not load image: {image_path}")
        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        plt.title("Original Query")
        plt.axis('off')

        # Plot processed query image
        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')

        # Plot top N predictions with their reference images
        for i, (label, prob) in enumerate(predictions, self._QUERY_PANELS + 1):
            # Find the first existing reference image for this label.
            ref_path = None
            for ext in self._REFERENCE_EXTENSIONS:
                test_path = os.path.join(reference_dir, str(label) + ext)
                if os.path.exists(test_path):
                    ref_path = test_path
                    break
            if ref_path is None:
                continue

            ref_img = cv2.imread(ref_path)
            if ref_img is None:
                # File exists but is not a readable image; treat as missing.
                continue

            plt.subplot(n_rows, n_cols, i)
            plt.imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            plt.title(f"{label}\n{prob:.1%}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()
def main():
    """Interactively classify one coin image and display the results.

    Builds a ``CoinPredictor`` with default settings, prompts for an image
    path, prints the ranked predictions, and opens the visualization.
    Errors raised while predicting or visualizing are reported on stdout
    instead of propagating.
    """
    # Initialize predictor
    predictor = CoinPredictor()

    # Terminals often add surrounding quotes or trailing spaces to pasted or
    # drag-and-dropped paths; remove them so the file can be found.
    raw_path = input("Enter the path to the coin image: ")
    image_path = raw_path.strip().strip('"\'')

    try:
        # predict returns (predictions, preprocessed_image); the whole tuple
        # is what visualize_prediction expects.
        result = predictor.predict(image_path)
        ranked, _ = result

        print("\nPredictions:")
        for rank, (label, prob) in enumerate(ranked, 1):
            print(f"{rank}. {label}: {prob:.1%}")

        # Visualize results
        predictor.visualize_prediction(image_path, result)
    except Exception as e:
        # Top-level boundary of the script: report the failure to the user.
        print(f"Error: {e}")
# Run the interactive prediction flow only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()