YYama0 commited on
Commit
1fcfe94
·
verified ·
1 Parent(s): afc86cd

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +282 -0
inference.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RadFig VQA Image Filtering Model - Inference Script
3
+ Classifies medical images as suitable/unsuitable for VQA tasks.
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import timm
10
+ import cv2
11
+ import numpy as np
12
+ import pandas as pd
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset, DataLoader
15
+ from albumentations import Compose, Resize, Normalize
16
+ from albumentations.pytorch import ToTensorV2
17
+ from tqdm import tqdm
18
+
19
+
20
+ class Config:
21
+ """Configuration for inference"""
22
+ model_name = "tf_efficientnetv2_s.in21k_ft_in1k"
23
+ size = 512
24
+ batch_size = 32
25
+ num_workers = 4
26
+ target_size = 1
27
+ n_fold = 5
28
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
+
30
+
31
+ class TestDataset(Dataset):
32
+ """Dataset for inference"""
33
+
34
+ def __init__(self, image_paths, transform=None):
35
+ self.image_paths = image_paths
36
+ self.transform = transform
37
+
38
+ def __len__(self):
39
+ return len(self.image_paths)
40
+
41
+ def __getitem__(self, idx):
42
+ image_path = self.image_paths[idx]
43
+
44
+ # Load image
45
+ image = cv2.imread(image_path)
46
+ if image is None:
47
+ raise ValueError(f"Could not load image: {image_path}")
48
+
49
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50
+
51
+ if self.transform:
52
+ augmented = self.transform(image=image)
53
+ image = augmented['image']
54
+
55
+ return image
56
+
57
+
58
+ def get_transforms():
59
+ """Get inference transforms"""
60
+ return Compose([
61
+ Resize(Config.size, Config.size),
62
+ Normalize(
63
+ mean=[0.485, 0.456, 0.406],
64
+ std=[0.229, 0.224, 0.225],
65
+ ),
66
+ ToTensorV2(),
67
+ ])
68
+
69
+
70
+ class RadFigClassifier:
71
+ """RadFig VQA Image Filtering Classifier"""
72
+
73
+ def __init__(self, model_dir="models"):
74
+ self.config = Config()
75
+ self.model_dir = model_dir
76
+ self.device = self.config.device
77
+ self.model = None
78
+ self.states = []
79
+
80
+ # Load model states
81
+ self._load_model_states()
82
+
83
+ def _load_model_states(self):
84
+ """Load all fold model states"""
85
+ self.states = []
86
+ for fold in range(self.config.n_fold):
87
+ model_path = os.path.join(
88
+ self.model_dir,
89
+ f"{self.config.model_name}_fold{fold}_best_loss.pth"
90
+ )
91
+
92
+ if not os.path.exists(model_path):
93
+ raise FileNotFoundError(f"Model file not found: {model_path}")
94
+
95
+ state = torch.load(model_path, map_location=self.device)
96
+ self.states.append(state)
97
+
98
+ print(f"Loaded {len(self.states)} model states from {self.model_dir}")
99
+
100
+ def _create_model(self):
101
+ """Create model architecture"""
102
+ model = timm.create_model(
103
+ model_name=self.config.model_name,
104
+ num_classes=self.config.target_size,
105
+ pretrained=False
106
+ )
107
+ return model.to(self.device)
108
+
109
+ def predict_batch(self, image_paths, return_probabilities=True):
110
+ """
111
+ Predict on a batch of images
112
+
113
+ Args:
114
+ image_paths (list): List of image file paths
115
+ return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
116
+
117
+ Returns:
118
+ numpy.ndarray: Predictions (probabilities or binary)
119
+ """
120
+ # Create dataset and dataloader
121
+ dataset = TestDataset(image_paths, transform=get_transforms())
122
+ dataloader = DataLoader(
123
+ dataset,
124
+ batch_size=self.config.batch_size,
125
+ shuffle=False,
126
+ num_workers=self.config.num_workers,
127
+ pin_memory=True
128
+ )
129
+
130
+ # Create model
131
+ model = self._create_model()
132
+
133
+ all_predictions = []
134
+
135
+ # Inference loop
136
+ with torch.no_grad():
137
+ for images in tqdm(dataloader, desc="Predicting"):
138
+ images = images.to(self.device)
139
+
140
+ # Ensemble predictions across all folds
141
+ fold_predictions = []
142
+
143
+ for state in self.states:
144
+ model.load_state_dict(state['model'])
145
+ model.eval()
146
+
147
+ outputs = model(images)
148
+ probabilities = torch.sigmoid(outputs).cpu().numpy()
149
+ fold_predictions.append(probabilities)
150
+
151
+ # Average predictions across folds
152
+ avg_predictions = np.mean(fold_predictions, axis=0)
153
+ all_predictions.append(avg_predictions)
154
+
155
+ # Concatenate all predictions
156
+ predictions = np.concatenate(all_predictions, axis=0).flatten()
157
+
158
+ if return_probabilities:
159
+ return predictions
160
+ else:
161
+ return (predictions > 0.5).astype(int)
162
+
163
+ def predict_single(self, image_path, return_probability=True):
164
+ """
165
+ Predict on a single image
166
+
167
+ Args:
168
+ image_path (str): Path to image file
169
+ return_probability (bool): If True, return probability. If False, return binary prediction.
170
+
171
+ Returns:
172
+ float or int: Prediction
173
+ """
174
+ predictions = self.predict_batch([image_path], return_probabilities=return_probability)
175
+ return predictions[0]
176
+
177
+ def predict_directory(self, directory_path, output_csv=None, return_probabilities=True):
178
+ """
179
+ Predict on all images in a directory
180
+
181
+ Args:
182
+ directory_path (str): Path to directory containing images
183
+ output_csv (str, optional): Path to save results as CSV
184
+ return_probabilities (bool): If True, return probabilities. If False, return binary predictions.
185
+
186
+ Returns:
187
+ pandas.DataFrame: Results with image paths and predictions
188
+ """
189
+ # Get all image files
190
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
191
+ image_paths = []
192
+
193
+ for filename in os.listdir(directory_path):
194
+ if any(filename.lower().endswith(ext) for ext in image_extensions):
195
+ image_paths.append(os.path.join(directory_path, filename))
196
+
197
+ if not image_paths:
198
+ raise ValueError(f"No image files found in {directory_path}")
199
+
200
+ print(f"Found {len(image_paths)} images in {directory_path}")
201
+
202
+ # Get predictions
203
+ predictions = self.predict_batch(image_paths, return_probabilities=return_probabilities)
204
+
205
+ # Create results dataframe
206
+ results = pd.DataFrame({
207
+ 'image_path': image_paths,
208
+ 'filename': [os.path.basename(path) for path in image_paths],
209
+ 'prediction': predictions,
210
+ 'suitable_for_vqa': predictions > 0.5 if return_probabilities else predictions.astype(bool)
211
+ })
212
+
213
+ # Sort by filename for consistency
214
+ results = results.sort_values('filename').reset_index(drop=True)
215
+
216
+ # Save to CSV if requested
217
+ if output_csv:
218
+ results.to_csv(output_csv, index=False)
219
+ print(f"Results saved to {output_csv}")
220
+
221
+ return results
222
+
223
+
224
+ def main():
225
+ """Example usage"""
226
+ import argparse
227
+
228
+ parser = argparse.ArgumentParser(description="RadFig VQA Image Filtering Inference")
229
+ parser.add_argument("--input", required=True, help="Input image file or directory")
230
+ parser.add_argument("--models", default="models", help="Directory containing model files")
231
+ parser.add_argument("--output", help="Output CSV file (for directory input)")
232
+ parser.add_argument("--binary", action="store_true", help="Return binary predictions instead of probabilities")
233
+
234
+ args = parser.parse_args()
235
+
236
+ # Initialize classifier
237
+ classifier = RadFigClassifier(model_dir=args.models)
238
+
239
+ if os.path.isfile(args.input):
240
+ # Single image prediction
241
+ prediction = classifier.predict_single(
242
+ args.input,
243
+ return_probability=not args.binary
244
+ )
245
+
246
+ if args.binary:
247
+ result = "suitable" if prediction else "not suitable"
248
+ print(f"Image: {args.input}")
249
+ print(f"Prediction: {result} for VQA")
250
+ else:
251
+ print(f"Image: {args.input}")
252
+ print(f"Probability suitable for VQA: {prediction:.4f}")
253
+ print(f"Classification: {'suitable' if prediction > 0.5 else 'not suitable'}")
254
+
255
+ elif os.path.isdir(args.input):
256
+ # Directory prediction
257
+ results = classifier.predict_directory(
258
+ args.input,
259
+ output_csv=args.output,
260
+ return_probabilities=not args.binary
261
+ )
262
+
263
+ # Print summary
264
+ if args.binary:
265
+ suitable_count = results['suitable_for_vqa'].sum()
266
+ else:
267
+ suitable_count = (results['prediction'] > 0.5).sum()
268
+
269
+ total_count = len(results)
270
+
271
+ print(f"\nSummary:")
272
+ print(f"Total images: {total_count}")
273
+ print(f"Suitable for VQA: {suitable_count}")
274
+ print(f"Not suitable for VQA: {total_count - suitable_count}")
275
+ print(f"Percentage suitable: {suitable_count/total_count*100:.1f}%")
276
+
277
+ else:
278
+ print(f"Error: {args.input} is not a valid file or directory")
279
+
280
+
281
+ if __name__ == "__main__":
282
+ main()