OlejnikM commited on
Commit
c179902
·
verified ·
1 Parent(s): d7f77ee

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +178 -0
inference.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import SwinForImageClassification, AutoImageProcessor
3
+ from PIL import Image
4
+ import joblib
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import cv2
8
+ import os
9
+ from pathlib import Path
10
+
11
class CoinPredictor:
    """Coin-image classifier built on a fine-tuned Swin transformer.

    Loads a trained ``SwinForImageClassification`` checkpoint together with
    the matching label encoder, and provides single-image prediction plus a
    matplotlib visualization of the top-N matches against reference images.
    """

    def __init__(self, model_dir='model_checkpoints', top_n=10):
        """
        Initialize the predictor with trained model and necessary components.

        Args:
            model_dir (str): Directory containing the saved model and label encoder
            top_n (int): Number of top predictions to return
        """
        self.top_n = top_n
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.crop_percentage = 0.15  # 15% crop from each side

        # Processor for the same backbone the model was fine-tuned from.
        self.image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        # Load the trained model
        model_path = os.path.join(model_dir, 'best_model')
        self.model = SwinForImageClassification.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # Label encoder maps predicted class indices back to coin labels.
        encoder_path = os.path.join(model_dir, 'label_encoder.joblib')
        self.label_encoder = joblib.load(encoder_path)

        print(f"Model loaded and running on {self.device}")

    def crop_center(self, image):
        """Crop ``crop_percentage`` of the image away from each side.

        Args:
            image: H x W (x C) array-like with a ``.shape`` attribute.

        Returns:
            The centered crop (a view/slice of the input, not a copy).
        """
        h, w = image.shape[:2]
        crop_h = int(h * self.crop_percentage)
        crop_w = int(w * self.crop_percentage)

        return image[crop_h:h - crop_h, crop_w:w - crop_w]

    def preprocess_image(self, image_path):
        """Load, center-crop, and convert a single image for prediction.

        Args:
            image_path (str): Path to the image file on disk.

        Returns:
            PIL.Image.Image: RGB image ready for the image processor.

        Raises:
            ValueError: If the file cannot be read as an image.
        """
        # Read image (cv2 returns None instead of raising on failure)
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # Crop center
        image = self.crop_center(image)

        # Convert to RGB (from BGR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Convert to PIL Image
        image = Image.fromarray(image)

        return image

    def predict(self, image_path):
        """
        Make prediction for a single image.

        Args:
            image_path (str): Path to the image file

        Returns:
            tuple: (predictions, image) where ``predictions`` is a list of
            ``(label, probability)`` tuples for the top N classes and
            ``image`` is the preprocessed PIL image used for inference.
        """
        # Preprocess image
        image = self.preprocess_image(image_path)

        # Prepare image for model
        inputs = self.image_processor(image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Get top N predictions
        top_probs, top_indices = torch.topk(probabilities[0], self.top_n)

        # Convert to labels and probabilities
        predictions = []
        for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy()):
            label = self.label_encoder.inverse_transform([idx])[0]
            predictions.append((label, float(prob)))

        return predictions, image

    def visualize_prediction(self, image_path, predictions, reference_dir="all_coins_cropped"):
        """
        Visualize the input image and top N matching reference images with probabilities.

        Args:
            image_path (str): Path to the query image
            predictions (tuple): (predictions, preprocessed_image)
            reference_dir (str): Directory containing reference images
        """
        predictions, processed_image = predictions

        # Calculate grid size: 2 query panels + top_n reference panels,
        # rounded up to full rows. (The previous (top_n + 3) // n_cols
        # under-allocated rows whenever top_n % n_cols == n_cols - 1.)
        n_cols = 4  # 4 images per row
        n_rows = (self.top_n + 2 + n_cols - 1) // n_cols

        # Create figure
        fig = plt.figure(figsize=(15, 4 * n_rows))

        # Plot original query image
        plt.subplot(n_rows, n_cols, 1)
        original_img = cv2.imread(image_path)
        original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
        plt.imshow(original_img)
        plt.title("Original Query")
        plt.axis('off')

        # Plot processed query image
        plt.subplot(n_rows, n_cols, 2)
        plt.imshow(processed_image)
        plt.title("Processed Query")
        plt.axis('off')

        # Plot top N predictions with their reference images
        # (subplot slots start at 3, after the two query panels)
        for i, (label, prob) in enumerate(predictions, 3):
            # Find reference image; try common extensions in turn
            ref_path = None
            for ext in ['.jpg', '.jpeg', '.png']:
                test_path = os.path.join(reference_dir, label + ext)
                if os.path.exists(test_path):
                    ref_path = test_path
                    break

            if ref_path:
                plt.subplot(n_rows, n_cols, i)
                ref_img = cv2.imread(ref_path)
                ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB)
                plt.imshow(ref_img)
                plt.title(f"{label}\n{prob:.1%}")
                plt.axis('off')

        plt.tight_layout()
        plt.show()
154
+
155
def main():
    """Interactive entry point: prompt for an image path, print and plot the top predictions."""
    predictor = CoinPredictor()

    image_path = input("Enter the path to the coin image: ")

    try:
        result = predictor.predict(image_path)
        top_matches = result[0]

        print("\nPredictions:")
        for rank, (label, prob) in enumerate(top_matches, start=1):
            print(f"{rank}. {label}: {prob:.1%}")

        predictor.visualize_prediction(image_path, result)
    except Exception as e:
        # Top-level script boundary: surface any failure (bad path,
        # unreadable image, missing checkpoint) without a traceback.
        print(f"Error: {e}")


if __name__ == "__main__":
    main()