|
|
import torch
|
|
|
from transformers import SwinForImageClassification, AutoImageProcessor
|
|
|
from PIL import Image
|
|
|
import joblib
|
|
|
import numpy as np
|
|
|
import matplotlib.pyplot as plt
|
|
|
import cv2
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
|
|
|
class CoinPredictor:
    """Classify coin images with a fine-tuned Swin model and plot the top matches."""

    # Extensions tried, in order, when looking up a reference image for a predicted label.
    _REFERENCE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, model_dir='model_checkpoints', top_n=10, crop_percentage=0.15,
                 processor_name="microsoft/swin-tiny-patch4-window7-224"):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model (in its
                ``best_model`` subdirectory) and the label encoder
                (``label_encoder.joblib``).
            top_n (int): Number of top predictions to return. At prediction time
                it is clamped to the number of classes the model outputs.
            crop_percentage (float): Fraction of the height/width removed from
                *each* side by ``crop_center``; must be in ``[0, 0.5)``.
            processor_name (str): Hugging Face identifier of the image processor
                used to prepare model inputs.

        Raises:
            ValueError: If ``crop_percentage`` is outside ``[0, 0.5)`` (0.5 or
                more would crop the image away entirely).
        """
        if not 0 <= crop_percentage < 0.5:
            raise ValueError(f"crop_percentage must be in [0, 0.5), got {crop_percentage}")

        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = crop_percentage

        self.image_processor = AutoImageProcessor.from_pretrained(processor_name)

        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()  # evaluation mode for inference (e.g. disables dropout)

        # Maps the class indices produced by the model back to labels.
        # NOTE: joblib.load unpickles the file, so only load encoders from trusted sources.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Removes ``crop_percentage`` of the height from both the top and the
        bottom, and of the width from both the left and the right.

        Args:
            image: Array indexed as ``image[row, col, ...]`` (e.g. as returned
                by ``cv2.imread``).

        Returns:
            The sliced central region of ``image``.
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)
        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Load an image from disk and prepare it for prediction.

        The image is read with OpenCV (BGR channel order), center-cropped with
        ``crop_center``, converted to RGB and wrapped in a PIL image.

        Args:
            image_path (str or os.PathLike): Path to the image file.

        Returns:
            PIL.Image.Image: The cropped RGB image.

        Raises:
            ValueError: If the file cannot be read as an image.
        """
        # Convert to str so pathlib.Path arguments work with cv2.imread as well.
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        image = self.crop_center(image)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image)

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str or os.PathLike): Path to the image file.

        Returns:
            tuple: ``(predictions, image)``. ``predictions`` is a list of
            ``(label, probability)`` tuples for the top
            ``min(top_n, num_classes)`` classes, ordered by descending
            probability. ``image`` is the preprocessed PIL image. The whole
            tuple can be passed directly to ``visualize_prediction``.

        Raises:
            ValueError: If the image cannot be loaded.
        """
        image = self.preprocess_image(image_path)

        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Single image -> take the first (only) row of the batch.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # torch.topk raises if k exceeds the number of classes, so clamp top_n.
        k = min(self.top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        labels = self.label_encoder.inverse_transform(top_indices.cpu().numpy())
        predictions = [
            (label, float(prob))
            for label, prob in zip(labels, top_probs.cpu().tolist())
        ]
        return predictions, image

    @staticmethod
    def _find_reference_image(reference_dir, label):
        """Return the first existing ``<label><ext>`` path in ``reference_dir``, or None."""
        for ext in CoinPredictor._REFERENCE_EXTENSIONS:
            candidate = os.path.join(reference_dir, f"{label}{ext}")
            if os.path.exists(candidate):
                return candidate
        return None

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        The first two panels show the original and the preprocessed query. They
        are followed by one panel per prediction, showing the reference image
        ``<reference_dir>/<label>.{jpg,jpeg,png}``. Predictions whose reference
        image is missing or unreadable leave their panel empty.

        Args:
            image_path (str or os.PathLike): Path to the query image.
            predictions (tuple): ``(predictions, preprocessed_image)``, as
                returned by ``predict``.
            reference_dir (str): Directory containing reference images.

        Raises:
            ValueError: If the query image cannot be loaded.
        """
        top_predictions, processed_image = predictions

        original_img = cv2.imread(str(image_path))
        if original_img is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Grid: two query panels followed by one panel per prediction.
        n_cols = 4
        n_panels = 2 + len(top_predictions)
        n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division

        plt.figure(figsize=(15, 4 * n_rows))

        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        plt.title("Original Query")
        plt.axis('off')

        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')

        for panel, (label, prob) in enumerate(top_predictions, 3):
            ref_path = self._find_reference_image(reference_dir, label)
            if ref_path is None:
                continue
            ref_img = cv2.imread(ref_path)
            if ref_img is None:  # exists but is not a readable image
                continue

            plt.subplot(n_rows, n_cols, panel)
            plt.imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            plt.title(f"{label}\n{prob:.1%}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()
|
|
|
|
|
|
def main():
    """
    Interactive entry point.

    Loads the model, prompts for an image path, prints the top predictions
    and shows them next to their reference images.
    """
    predictor = CoinPredictor()

    # Tolerate surrounding whitespace/quotes, e.g. from pasting or drag-and-dropping a path.
    image_path = input("Enter the path to the coin image: ").strip().strip('"\'')
    if not image_path:
        print("Error: no image path given")
        return

    try:
        # predict() returns (predictions, preprocessed_image);
        # visualize_prediction() accepts that tuple unchanged.
        result = predictor.predict(image_path)

        print("\nPredictions:")
        for rank, (label, prob) in enumerate(result[0], 1):
            print(f"{rank}. {label}: {prob:.1%}")

        predictor.visualize_prediction(image_path, result)

    except Exception as e:
        # Top-level boundary: report any failure (bad path, unreadable image, ...)
        # as a one-line message instead of a traceback.
        print(f"Error: {e}")


if __name__ == "__main__":
    main()