Upload inference.py
Browse files- inference.py +178 -0
inference.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import SwinForImageClassification, AutoImageProcessor
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import joblib
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import cv2
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
class CoinPredictor:
    """Classify coin images with a fine-tuned Swin image-classification model.

    The predictor loads a ``SwinForImageClassification`` checkpoint and a
    joblib-serialized label encoder from ``model_dir``, center-crops each
    query image, and returns the ``top_n`` most probable labels.
    """

    def __init__(self, model_dir='model_checkpoints', top_n=10):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model
                (``best_model``) and label encoder (``label_encoder.joblib``)
            top_n (int): Number of top predictions to return
        """
        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = 0.15  # 15% crop from each side

        # Load the image processor
        self.image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        # Load the trained model and switch it to inference mode
        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Load the label encoder.
        # NOTE(security): joblib.load unpickles arbitrary objects; only point
        # model_dir at checkpoints you trust.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    def crop_center(self, image):
        """
        Crop the center portion of the image.

        Removes ``self.crop_percentage`` of the height from the top and
        bottom, and of the width from the left and right.

        Args:
            image: Array-like image indexed as ``[row, column, ...]`` that
                exposes ``.shape``

        Returns:
            The cropped slice of ``image``
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)

        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """
        Preprocess a single image for prediction.

        Args:
            image_path (str or os.PathLike): Path to the image file

        Returns:
            PIL.Image.Image: Center-cropped RGB image

        Raises:
            ValueError: If the image cannot be read
        """
        # cv2.imread signals failure by returning None rather than raising
        image = cv2.imread(os.fspath(image_path))
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Crop center
        image = self.crop_center(image)

        # Convert to RGB (OpenCV loads images as BGR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert to PIL Image
        return Image.fromarray(image)

    def _decode_top_k(self, probabilities):
        """Return ``(label, probability)`` pairs for the most probable classes.

        Args:
            probabilities (torch.Tensor): 1-D tensor of class probabilities

        Returns:
            list of tuples: At most ``self.top_n`` pairs, most probable first
        """
        # torch.topk raises when k exceeds the number of classes, so clamp it
        k = max(0, min(self.top_n, probabilities.shape[-1]))
        if k == 0:
            return []

        top_probs, top_indices = torch.topk(probabilities, k)

        # Decode all indices in one call instead of one call per index
        labels = self.label_encoder.inverse_transform(top_indices.cpu().numpy())
        return [
            (label, float(prob))
            for label, prob in zip(labels, top_probs.cpu().tolist())
        ]

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str): Path to the image file

        Returns:
            tuple: ``(predictions, image)`` where ``predictions`` is a list of
            ``(label, probability)`` tuples for the top N classes and
            ``image`` is the preprocessed PIL image fed to the model

        Raises:
            ValueError: If the image cannot be read
        """
        # Preprocess image
        image = self.preprocess_image(image_path)

        # Prepare image for model
        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Single-image batch: take row 0 and decode the top N classes
        return self._decode_top_k(probabilities[0]), image

    @staticmethod
    def _grid_rows(n_cells, n_cols):
        """Return the number of rows needed to fit ``n_cells`` in ``n_cols`` columns."""
        # Ceiling division; always keep at least one row
        return max(1, -(-n_cells // n_cols))

    @staticmethod
    def _find_reference(reference_dir, label):
        """Return the path of the reference image for ``label``, or None if absent."""
        for ext in ('.jpg', '.jpeg', '.png'):
            candidate = os.path.join(reference_dir, f"{label}{ext}")
            if os.path.exists(candidate):
                return candidate
        return None

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        Args:
            image_path (str): Path to the query image
            predictions (tuple): (predictions, preprocessed_image) as returned
                by ``predict``
            reference_dir (str): Directory containing reference images named
                ``<label>.jpg``, ``<label>.jpeg`` or ``<label>.png``

        Raises:
            ValueError: If the query image cannot be read
        """
        predictions, processed_image = predictions

        # Load the query image before creating the figure so a bad path does
        # not leave an empty figure behind
        original_img = cv2.imread(os.fspath(image_path))
        if original_img is None:
            raise ValueError(f"Could not load image: {image_path}")
        original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)

        # Calculate grid size: two query cells plus one cell per prediction
        n_cols = 4  # 4 images per row
        n_rows = self._grid_rows(len(predictions) + 2, n_cols)

        # Create figure
        fig = plt.figure(figsize=(15, 4 * n_rows))

        # Plot original query image
        ax = fig.add_subplot(n_rows, n_cols, 1)
        ax.imshow(original_img)
        ax.set_title("Original Query")
        ax.axis('off')

        # Plot processed query image
        ax = fig.add_subplot(n_rows, n_cols, 2)
        ax.imshow(processed_image)
        ax.set_title("Processed Query")
        ax.axis('off')

        # Plot top N predictions with their reference images
        for cell, (label, prob) in enumerate(predictions, 3):
            ref_path = self._find_reference(reference_dir, label)
            ref_img = cv2.imread(ref_path) if ref_path else None

            ax = fig.add_subplot(n_rows, n_cols, cell)
            if ref_img is not None:
                ax.imshow(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            else:
                # Keep the prediction visible even without a usable reference
                ax.text(0.5, 0.5, "No reference image", ha='center', va='center')
            ax.set_title(f"{label}\n{prob:.1%}")
            ax.axis('off')

        fig.tight_layout()
        plt.show()
|
| 154 |
+
|
| 155 |
+
def main():
    """Prompt for a coin image path, print the top predictions, and plot them.

    Errors raised while predicting or visualizing are reported on stdout
    instead of propagating, since this is the interactive entry point.
    """
    # Initialize predictor
    predictor = CoinPredictor()

    # Get image path from user; pasted or dragged paths often carry stray
    # whitespace or surrounding quotes that would make the file unreadable
    image_path = input("Enter the path to the coin image: ").strip().strip('"\'')

    try:
        # predict returns (list of (label, probability), preprocessed image)
        predictions, processed_image = predictor.predict(image_path)

        # Print results
        print("\nPredictions:")
        for rank, (label, prob) in enumerate(predictions, 1):
            print(f"{rank}. {label}: {prob:.1%}")

        # Visualize results
        predictor.visualize_prediction(image_path, (predictions, processed_image))

    except Exception as e:
        print(f"Error: {e}")
|
| 176 |
+
|
| 177 |
+
# Run the interactive demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|