File size: 6,214 Bytes
c179902 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import torch
from transformers import SwinForImageClassification, AutoImageProcessor
from PIL import Image
import joblib
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from pathlib import Path
class CoinPredictor:
    """Classify coin images with a fine-tuned Swin transformer.

    Loads a saved `SwinForImageClassification` checkpoint plus the matching
    scikit-learn label encoder, and exposes prediction and visualization
    helpers for single images.
    """

    def __init__(self, model_dir='model_checkpoints', top_n=10):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model and label encoder
            top_n (int): Number of top predictions to return
        """
        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = 0.15  # fraction cropped from each side (15%)
        # Processor for the same backbone the checkpoint was fine-tuned from,
        # so resizing/normalization match training.
        self.image_processor = AutoImageProcessor.from_pretrained(
            "microsoft/swin-tiny-patch4-window7-224")
        # Load the fine-tuned classification head + weights.
        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()  # inference mode: disable dropout etc.
        # Label encoder maps class indices back to human-readable labels.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)
        print(f"Model loaded and running on {self.device}")

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Args:
            image (np.ndarray): H x W (x C) image array.

        Returns:
            np.ndarray: Center region after removing `crop_percentage` of the
            height and width from each side.
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)
        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Preprocess a single image for prediction.

        Args:
            image_path (str or Path): Path to the image file.

        Returns:
            PIL.Image.Image: Center-cropped RGB image.

        Raises:
            ValueError: If the image cannot be read from disk.
        """
        # str() so pathlib.Path inputs work too; cv2.imread requires a string.
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")
        # Crop center to drop background around the coin.
        image = self.crop_center(image)
        # OpenCV loads BGR; the processor expects RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return Image.fromarray(image)

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str or Path): Path to the image file

        Returns:
            tuple: ``(predictions, image)`` where ``predictions`` is a list of
            ``(label, probability)`` tuples for the top N classes and
            ``image`` is the preprocessed PIL image used for inference.
        """
        image = self.preprocess_image(image_path)
        # Tokenize/normalize and move tensors to the model's device.
        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Clamp k so torch.topk cannot fail when top_n exceeds the class count.
        k = min(self.top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities[0], k)
        # Map class indices back to string labels.
        predictions = []
        for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy()):
            label = self.label_encoder.inverse_transform([idx])[0]
            predictions.append((label, float(prob)))
        return predictions, image

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        Args:
            image_path (str): Path to the query image
            predictions (tuple): (predictions, preprocessed_image)
            reference_dir (str): Directory containing reference images
        """
        predictions, processed_image = predictions
        # Grid layout: 4 images per row; +3 accounts for the two query
        # panels (cells 1 and 2) and the ceiling division.
        n_cols = 4
        n_rows = (self.top_n + 3) // n_cols
        fig = plt.figure(figsize=(15, 4 * n_rows))
        # Cell 1: the raw query image.
        plt.subplot(n_rows, n_cols, 1)
        original_img = cv2.imread(str(image_path))
        original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
        plt.imshow(original_img)
        plt.title("Original Query")
        plt.axis('off')
        # Cell 2: the cropped/preprocessed query image.
        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')
        # Cells 3..: reference image for each predicted label, if one exists.
        for i, (label, prob) in enumerate(predictions, 3):
            ref_path = None
            for ext in ['.jpg', '.jpeg', '.png']:
                test_path = os.path.join(reference_dir, label + ext)
                if os.path.exists(test_path):
                    ref_path = test_path
                    break
            if ref_path:
                plt.subplot(n_rows, n_cols, i)
                ref_img = cv2.imread(ref_path)
                ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
                plt.imshow(ref_img)
                plt.title(f"{label}\n{prob:.1%}")
                plt.axis('off')
        plt.tight_layout()
        plt.show()
def main():
    """Prompt for a coin image path, classify it, and display the results."""
    predictor = CoinPredictor()
    image_path = input("Enter the path to the coin image: ")
    try:
        result = predictor.predict(image_path)
        print("\nPredictions:")
        for rank, (label, prob) in enumerate(result[0], start=1):
            print(f"{rank}. {label}: {prob:.1%}")
        # Show the query image alongside the top reference matches.
        predictor.visualize_prediction(image_path, result)
    except Exception as err:
        # Top-level boundary: report rather than crash the interactive session.
        print(f"Error: {err}")
# Entry point: trailing "|" artifact removed from the final line, which made
# the file a syntax error.
if __name__ == "__main__":
    main()