"""Command-line coin classifier built on a trained Swin image-classification model.

Loads a saved model and label encoder, predicts the most likely labels for a
user-supplied image, and plots the query next to matching reference images.
"""

import os
from pathlib import Path

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification


class CoinPredictor:
    """Predict coin labels for an image and visualize the top matches."""

    # File extensions tried, in order, when looking up a reference image.
    _REFERENCE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, model_dir='model_checkpoints', top_n=10):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model
                (``best_model``) and label encoder (``label_encoder.joblib``)
            top_n (int): Number of top predictions to return
        """
        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = 0.15  # 15% crop from each side

        # The image processor comes from the base checkpoint, not from model_dir.
        self.image_processor = AutoImageProcessor.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224"
        )

        # Load the trained model and switch it to inference mode.
        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code; only point model_dir at trusted checkpoints.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Removes ``self.crop_percentage`` of the height from the top and bottom
        and the same fraction of the width from the left and right.

        Args:
            image: Array-like image indexed as ``[row, column, ...]``

        Returns:
            The cropped view of ``image``
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)
        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Preprocess a single image for prediction.

        Args:
            image_path (str or os.PathLike): Path to the image file

        Returns:
            PIL.Image.Image: Center-cropped image in RGB channel order

        Raises:
            ValueError: If OpenCV cannot read the file
        """
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        image = self.crop_center(image)

        # OpenCV loads images as BGR; convert to RGB before handing to PIL.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image)

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str or os.PathLike): Path to the image file

        Returns:
            tuple: ``(predictions, image)`` where ``predictions`` is a list of
            ``(label, probability)`` tuples ordered from most to least likely
            (at most ``top_n`` entries) and ``image`` is the preprocessed PIL
            image that was fed to the model

        Raises:
            ValueError: If the image cannot be loaded
        """
        image = self.preprocess_image(image_path)

        # Prepare image for model
        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # torch.topk raises if k exceeds the number of classes, so clamp it.
        k = min(self.top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities[0], k)

        # Decode all indices in one call instead of once per prediction.
        labels = self.label_encoder.inverse_transform(top_indices.cpu().numpy())
        predictions = [
            (label, float(prob))
            for label, prob in zip(labels, top_probs.cpu().numpy())
        ]

        return predictions, image

    def _find_reference_image(self, reference_dir, label):
        """Return the path of the reference image for ``label``, or None if absent."""
        for ext in self._REFERENCE_EXTENSIONS:
            candidate = os.path.join(reference_dir, f"{label}{ext}")
            if os.path.exists(candidate):
                return candidate
        return None

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        The first two grid cells show the original and the preprocessed query;
        each prediction then occupies the next cell in rank order. A prediction
        whose reference image is missing or unreadable leaves its cell empty.

        Args:
            image_path (str or os.PathLike): Path to the query image
            predictions (tuple): (predictions, preprocessed_image) as returned
                by ``predict``
            reference_dir (str): Directory containing reference images named
                ``<label>.jpg``, ``<label>.jpeg`` or ``<label>.png``

        Raises:
            ValueError: If the query image cannot be loaded
        """
        ranked, processed_image = predictions

        # Load the query first so a bad path fails before a figure is created.
        original_img = cv2.imread(os.fspath(image_path))
        if original_img is None:
            raise ValueError(f"Could not load image: {image_path}")
        original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)

        # Two query panels plus one panel per prediction, 4 panels per row.
        n_cols = 4
        n_panels = len(ranked) + 2
        n_rows = (n_panels + n_cols - 1) // n_cols  # ceiling division

        plt.figure(figsize=(15, 4 * n_rows))

        # Plot original query image
        plt.subplot(n_rows, n_cols, 1)
        plt.imshow(original_img)
        plt.title("Original Query")
        plt.axis('off')

        # Plot processed query image
        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')

        # Plot each prediction's reference image, starting at grid cell 3.
        for position, (label, prob) in enumerate(ranked, 3):
            ref_path = self._find_reference_image(reference_dir, label)
            if ref_path is None:
                continue

            ref_img = cv2.imread(ref_path)
            if ref_img is None:
                # File exists but OpenCV cannot decode it; skip this cell.
                continue

            plt.subplot(n_rows, n_cols, position)
            plt.imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            plt.title(f"{label}\n{prob:.1%}")
            plt.axis('off')

        plt.tight_layout()
        plt.show()


def main():
    """Prompt for an image path, print the top predictions, and plot them."""
    # Initialize predictor
    predictor = CoinPredictor()

    # Get image path from user
    image_path = input("Enter the path to the coin image: ")

    # Top-level boundary: report any failure to the user instead of a traceback.
    try:
        result = predictor.predict(image_path)
        ranked, _ = result

        print("\nPredictions:")
        for i, (label, prob) in enumerate(ranked, 1):
            print(f"{i}. {label}: {prob:.1%}")

        # visualize_prediction expects the full (predictions, image) tuple.
        predictor.visualize_prediction(image_path, result)

    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    main()